From 6d152577525260aa5178066170ace3afd3dc78ec Mon Sep 17 00:00:00 2001 From: "Schlimbach, Frank" Date: Fri, 15 Nov 2024 12:38:11 +0100 Subject: [PATCH] cleanup: remove dist dialect and its transforms/conversions, reducescope of ndarray and eliminate NDArrayType --- include/imex/Conversion/CMakeLists.txt | 1 - .../Conversion/DistToStandard/CMakeLists.txt | 0 .../DistToStandard/DistToStandard.h | 44 - include/imex/Conversion/Passes.h | 1 - include/imex/Conversion/Passes.td | 30 - include/imex/Dialect/CMakeLists.txt | 1 - include/imex/Dialect/Dist/CMakeLists.txt | 3 - include/imex/Dialect/Dist/IR/CMakeLists.txt | 8 - include/imex/Dialect/Dist/IR/DistOps.h | 78 - include/imex/Dialect/Dist/IR/DistOps.td | 639 ------ .../Dialect/Dist/Transforms/CMakeLists.txt | 7 - include/imex/Dialect/Dist/Transforms/Passes.h | 53 - .../imex/Dialect/Dist/Transforms/Passes.td | 53 - .../imex/Dialect/Dist/Utils/CMakeLists.txt | 1 - include/imex/Dialect/Dist/Utils/Utils.h | 337 --- .../Dialect/DistRuntime/IR/DistRuntimeOps.td | 93 - .../Dialect/DistRuntime/Transforms/Passes.h | 4 - .../Dialect/DistRuntime/Transforms/Passes.td | 15 - .../imex/Dialect/NDArray/IR/CMakeLists.txt | 6 + include/imex/Dialect/NDArray/IR/NDArrayOps.h | 32 +- include/imex/Dialect/NDArray/IR/NDArrayOps.td | 400 +--- .../imex/Dialect/NDArray/Transforms/Passes.h | 13 +- .../imex/Dialect/NDArray/Transforms/Passes.td | 46 +- include/imex/Dialect/NDArray/Utils/Utils.h | 11 +- .../imex/Dialect/Region/Transforms/Passes.td | 2 +- include/imex/InitIMEXDialects.h | 6 +- include/imex/InitIMEXPasses.h | 2 - include/imex/Utils/PassUtils.h | 7 +- lib/Conversion/CMakeLists.txt | 1 - lib/Conversion/DistToStandard/CMakeLists.txt | 13 - .../DistToStandard/DistToStandard.cpp | 1897 ----------------- .../NDArrayToLinalg/NDArrayToLinalg.cpp | 1243 +++-------- lib/Dialect/CMakeLists.txt | 1 - lib/Dialect/Dist/CMakeLists.txt | 2 - lib/Dialect/Dist/IR/CMakeLists.txt | 12 - lib/Dialect/Dist/IR/DistOps.cpp | 229 -- lib/Dialect/Dist/Transforms/CMakeLists.txt | 15 - .../Transforms/DistInferElementwiseCores.cpp | 313 --- lib/Dialect/DistRuntime/IR/CMakeLists.txt | 1 - lib/Dialect/DistRuntime/IR/CopyPermuteOp.cpp | 11 +- lib/Dialect/DistRuntime/IR/CopyReshapeOp.cpp | 7 +- lib/Dialect/DistRuntime/IR/GetHaloOp.cpp | 180 -- .../Transforms/AddCommCacheKeys.cpp | 51 - .../DistRuntime/Transforms/CMakeLists.txt | 3 - .../Transforms/DistRuntimeToIDTR.cpp | 260 +-- .../Transforms/OverlapCommAndCompute.cpp | 210 -- .../Extensions/MeshShardingExtensions.cpp | 69 +- lib/Dialect/NDArray/IR/CMakeLists.txt | 6 +- lib/Dialect/NDArray/IR/CastElemTypeOp.cpp | 86 + lib/Dialect/NDArray/IR/CastOp.cpp | 229 -- lib/Dialect/NDArray/IR/CreateOp.cpp | 4 +- lib/Dialect/NDArray/IR/DimOp.cpp | 169 -- lib/Dialect/NDArray/IR/EWBinOp.cpp | 77 - lib/Dialect/NDArray/IR/EWOp.h | 36 - lib/Dialect/NDArray/IR/EWUnyOp.cpp | 75 - lib/Dialect/NDArray/IR/InsertSliceOp.cpp | 8 +- lib/Dialect/NDArray/IR/LinSpaceOp.cpp | 4 +- lib/Dialect/NDArray/IR/NDArrayOps.cpp | 105 +- lib/Dialect/NDArray/IR/PermuteDimsOp.cpp | 127 -- lib/Dialect/NDArray/IR/SubviewOp.cpp | 25 +- .../NDArray/Transforms/AddGPURegions.cpp | 30 +- lib/Dialect/NDArray/Transforms/CMakeLists.txt | 3 +- .../Transforms/CoalesceShardOps.cpp} | 327 +-- .../NDArray/Transforms/NDArrayDist.cpp | 223 -- .../DistToStandard/BoundingBox.mlir | 46 - .../DistToStandard/DefaultPartition.mlir | 36 - .../DistToStandard/DistToStandard.mlir | 254 --- test/Conversion/DistToStandard/Subview.mlir | 125 -- test/Dialect/Dist/IR/DistOps.mlir | 50 - test/Dialect/Dist/IR/lit.local.cfg | 7 - .../Dialect/Dist/Transforms/DistCoalesce.mlir | 70 - .../Dist/Transforms/DistInferEWCores.mlir | 84 - .../NDArray/Transforms/NDArrayDist.mlir | 60 - 73 files changed, 655 insertions(+), 8022 deletions(-) delete mode 100644 include/imex/Conversion/DistToStandard/CMakeLists.txt delete mode 100644 include/imex/Conversion/DistToStandard/DistToStandard.h delete mode 100644 include/imex/Dialect/Dist/CMakeLists.txt delete mode 100644 include/imex/Dialect/Dist/IR/CMakeLists.txt delete mode 100644 include/imex/Dialect/Dist/IR/DistOps.h delete mode 100644 include/imex/Dialect/Dist/IR/DistOps.td delete mode 100644 include/imex/Dialect/Dist/Transforms/CMakeLists.txt delete mode 100644 include/imex/Dialect/Dist/Transforms/Passes.h delete mode 100644 include/imex/Dialect/Dist/Transforms/Passes.td delete mode 100644 include/imex/Dialect/Dist/Utils/CMakeLists.txt delete mode 100644 include/imex/Dialect/Dist/Utils/Utils.h delete mode 100644 lib/Conversion/DistToStandard/CMakeLists.txt delete mode 100644 lib/Conversion/DistToStandard/DistToStandard.cpp delete mode 100644 lib/Dialect/Dist/CMakeLists.txt delete mode 100644 lib/Dialect/Dist/IR/CMakeLists.txt delete mode 100644 lib/Dialect/Dist/IR/DistOps.cpp delete mode 100644 lib/Dialect/Dist/Transforms/CMakeLists.txt delete mode 100644 lib/Dialect/Dist/Transforms/DistInferElementwiseCores.cpp delete mode 100644 lib/Dialect/DistRuntime/IR/GetHaloOp.cpp delete mode 100644 lib/Dialect/DistRuntime/Transforms/AddCommCacheKeys.cpp delete mode 100644 lib/Dialect/DistRuntime/Transforms/OverlapCommAndCompute.cpp create mode 100644 lib/Dialect/NDArray/IR/CastElemTypeOp.cpp delete mode 100644 lib/Dialect/NDArray/IR/CastOp.cpp delete mode 100644 lib/Dialect/NDArray/IR/DimOp.cpp delete mode 100644 lib/Dialect/NDArray/IR/EWBinOp.cpp delete mode 100644 lib/Dialect/NDArray/IR/EWOp.h delete mode 100644 lib/Dialect/NDArray/IR/EWUnyOp.cpp delete mode 100644 lib/Dialect/NDArray/IR/PermuteDimsOp.cpp rename lib/Dialect/{Dist/Transforms/DistCoalesce.cpp => NDArray/Transforms/CoalesceShardOps.cpp} (62%) delete mode 100644 lib/Dialect/NDArray/Transforms/NDArrayDist.cpp delete mode 100644 test/Conversion/DistToStandard/BoundingBox.mlir delete mode 100644 test/Conversion/DistToStandard/DefaultPartition.mlir delete mode 100644 test/Conversion/DistToStandard/DistToStandard.mlir delete mode 100644 test/Conversion/DistToStandard/Subview.mlir delete mode 100644 test/Dialect/Dist/IR/DistOps.mlir delete mode 100644 test/Dialect/Dist/IR/lit.local.cfg delete mode 100644 test/Dialect/Dist/Transforms/DistCoalesce.mlir delete mode 100644 test/Dialect/Dist/Transforms/DistInferEWCores.mlir delete mode 100644 test/Dialect/NDArray/Transforms/NDArrayDist.mlir diff --git a/include/imex/Conversion/CMakeLists.txt b/include/imex/Conversion/CMakeLists.txt index 635dcd12f..d20327019 100644 --- a/include/imex/Conversion/CMakeLists.txt +++ b/include/imex/Conversion/CMakeLists.txt @@ -5,7 +5,6 @@ mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix Conversion) add_public_tablegen_target(IMEXConversionPassIncGen) add_mlir_doc(Passes IMEXConversionPasses ./ -gen-pass-doc) -add_subdirectory(DistToStandard) add_subdirectory(DropRegions) add_subdirectory(XeTileToXeGPU) add_subdirectory(XeGPUToVC) diff --git a/include/imex/Conversion/DistToStandard/CMakeLists.txt b/include/imex/Conversion/DistToStandard/CMakeLists.txt deleted file mode 100644 index e69de29bb..000000000 diff --git a/include/imex/Conversion/DistToStandard/DistToStandard.h b/include/imex/Conversion/DistToStandard/DistToStandard.h deleted file mode 100644 index 4ea9421a6..000000000 --- a/include/imex/Conversion/DistToStandard/DistToStandard.h +++ /dev/null @@ -1,44 +0,0 @@ -//===- DistToStandard.h - DistToStandard conversion ------------*- C++ -*-===// -// -// Copyright 2023 Intel Corporation -// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -/// -/// \file -/// This file defines the DistToStandard conversion, converting the Dist -/// dialect to standard dialects. -/// -//===----------------------------------------------------------------------===// - -#ifndef _DistToStandard_H_INCLUDED_ -#define _DistToStandard_H_INCLUDED_ - -#include -#include -#include - -namespace mlir { -class LLVMTypeConverter; -class MLIRContext; -class ModuleOp; -template class OperationPass; -class RewritePatternSet; -} // namespace mlir - -namespace imex { -#define GEN_PASS_DECL_CONVERTDISTTOSTANDARD -#include "imex/Conversion/Passes.h.inc" - -/// Populate the given list with patterns rewrite Dist Ops -void populateDistToStandardConversionPatterns( - ::mlir::LLVMTypeConverter &converter, ::mlir::RewritePatternSet &patterns); - -/// Create a pass to convert the Dist dialect to the Standard dialect. -std::unique_ptr<::mlir::Pass> createConvertDistToStandardPass(); - -} // namespace imex - -#endif // _DistToStandard_H_INCLUDED_ diff --git a/include/imex/Conversion/Passes.h b/include/imex/Conversion/Passes.h index 0cb9c8c81..64806e7b7 100644 --- a/include/imex/Conversion/Passes.h +++ b/include/imex/Conversion/Passes.h @@ -17,7 +17,6 @@ #include "mlir/Pass/Pass.h" -#include #include #include #include diff --git a/include/imex/Conversion/Passes.td b/include/imex/Conversion/Passes.td index 1ad939d4a..53f6f4b08 100644 --- a/include/imex/Conversion/Passes.td +++ b/include/imex/Conversion/Passes.td @@ -57,43 +57,13 @@ def ConvertNDArrayToLinalg : Pass<"convert-ndarray-to-linalg"> { }]; let constructor = "imex::createConvertNDArrayToLinalgPass()"; let dependentDialects = ["::mlir::linalg::LinalgDialect", - "::mlir::affine::AffineDialect", - "::mlir::func::FuncDialect", - "::mlir::arith::ArithDialect", "::mlir::tensor::TensorDialect", - "::mlir::tosa::TosaDialect", - "::mlir::scf::SCFDialect", "::mlir::memref::MemRefDialect", - "::mlir::shape::ShapeDialect", "::mlir::bufferization::BufferizationDialect", "::imex::region::RegionDialect"]; let options = []; } -//===----------------------------------------------------------------------===// -// DistToStandard -//===----------------------------------------------------------------------===// - -def ConvertDistToStandard: Pass<"convert-dist-to-standard"> { - let summary = "Convert from the Dist dialect to runtime calls."; - let description = [{ - Convert Dist dialect operations into standard dialect operations - by inserting calls into a distributed runtime. - - Necessary prototypes of runtime functions will be added. - }]; - let constructor = "::imex::createConvertDistToStandardPass()"; - let dependentDialects = ["::imex::ndarray::NDArrayDialect", - "::imex::distruntime::DistRuntimeDialect", - "::mlir::linalg::LinalgDialect", - "::mlir::func::FuncDialect", - "::mlir::tensor::TensorDialect", - "::mlir::memref::MemRefDialect", - "::mlir::arith::ArithDialect", - "::mlir::scf::SCFDialect", - "::mlir::bufferization::BufferizationDialect"]; - let options = []; -} //===----------------------------------------------------------------------===// // DropRegions diff --git a/include/imex/Dialect/CMakeLists.txt b/include/imex/Dialect/CMakeLists.txt index 8fd29c73f..6d91a7df7 100644 --- a/include/imex/Dialect/CMakeLists.txt +++ b/include/imex/Dialect/CMakeLists.txt @@ -1,4 +1,3 @@ -add_subdirectory(Dist) add_subdirectory(DistRuntime) add_subdirectory(NDArray) add_subdirectory(Region) diff --git a/include/imex/Dialect/Dist/CMakeLists.txt b/include/imex/Dialect/Dist/CMakeLists.txt deleted file mode 100644 index 711752059..000000000 --- a/include/imex/Dialect/Dist/CMakeLists.txt +++ /dev/null @@ -1,3 +0,0 @@ -add_subdirectory(IR) -add_subdirectory(Utils) -add_subdirectory(Transforms) diff --git a/include/imex/Dialect/Dist/IR/CMakeLists.txt b/include/imex/Dialect/Dist/IR/CMakeLists.txt deleted file mode 100644 index d281623dc..000000000 --- a/include/imex/Dialect/Dist/IR/CMakeLists.txt +++ /dev/null @@ -1,8 +0,0 @@ -add_mlir_dialect(DistOps dist) -add_mlir_doc(DistOps DistDialect Dialects/ -gen-dialect-doc) - -set(LLVM_TARGET_DEFINITIONS DistOps.td) -mlir_tablegen(DistOpsAttrs.h.inc -gen-attrdef-decls -attrdefs-dialect=dist) -mlir_tablegen(DistOpsAttrs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=dist) -add_public_tablegen_target(MLIRDistIncGen) -add_dependencies(mlir-headers MLIRDistIncGen) diff --git a/include/imex/Dialect/Dist/IR/DistOps.h b/include/imex/Dialect/Dist/IR/DistOps.h deleted file mode 100644 index 50adc7b3a..000000000 --- a/include/imex/Dialect/Dist/IR/DistOps.h +++ /dev/null @@ -1,78 +0,0 @@ -//===- DistOps.h - Dist dialect -------------------------------*- C++ -*-===// -// -// Copyright 2023 Intel Corporation -// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -/// -/// \file -/// This file declares the Dist dialect and its basic operations. -/// -//===----------------------------------------------------------------------===// - -#ifndef _Dist_OPS_H_INCLUDED_ -#define _Dist_OPS_H_INCLUDED_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -namespace llvm { -template -hash_code hash_value(const SmallVector> &arg) { - hash_code hval((T)11); - for (const auto &v : arg) { - if (v.size()) - hval = - hash_combine(hval, hash_combine_range(v.data(), v.data() + v.size())); - } - return hval; -} -} // namespace llvm - -namespace imex { -namespace ndarray { -class NDArrayType; -} // namespace ndarray - -namespace dist { - -using ::mlir::DenseI64ArrayAttr; - -inline auto getBaseShardDimSize(int64_t shard, int64_t numShards, int64_t extend) { - return extend / numShards + (shard >= numShards - (extend % numShards) ? 1 : 0); -}; - -template -auto getBaseShardDimSize(T shard, T numShards, T extend) { - return extend / numShards + shard.sge(numShards - (extend % numShards)).select(1l, 0l); -}; - -template -auto getBaseShardDimOff(T shard, T numShards, T extend, T zero) { - return (shard * (extend / numShards)) + - (shard - (numShards - (extend % numShards))).max(zero); -}; - -} // namespace dist -} // namespace imex - -#include -#define GET_TYPEDEF_CLASSES -#include -#define GET_ATTRDEF_CLASSES -#include -#define GET_OP_CLASSES -#include - -#endif // _Dist_OPS_H_INCLUDED_ diff --git a/include/imex/Dialect/Dist/IR/DistOps.td b/include/imex/Dialect/Dist/IR/DistOps.td deleted file mode 100644 index a9331abd1..000000000 --- a/include/imex/Dialect/Dist/IR/DistOps.td +++ /dev/null @@ -1,639 +0,0 @@ -//===- DistOps.td - Dist dialect --------------------------*- tablegen -*-===// -// -// Copyright 2023 Intel Corporation -// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -/// -/// \file -/// This file defines basic operations of the Dist dialect. -/// -//===----------------------------------------------------------------------===// - -#ifndef _Dist_OPS_TD_INCLUDED_ -#define _Dist_OPS_TD_INCLUDED_ - -include "mlir/IR/OpBase.td" -include "mlir/IR/AttrTypeBase.td" -include "mlir/Interfaces/SideEffectInterfaces.td" -include "mlir/IR/BuiltinTypeInterfaces.td" -include "mlir/Interfaces/InferTypeOpInterface.td" -include "mlir/IR/OpAsmInterface.td" -include "mlir/IR/BuiltinTypes.td" -include "mlir/IR/CommonAttrConstraints.td" -include "mlir/Dialect/Mesh/IR/MeshBase.td" - -// Provide a definition of the 'Dist' dialect in the ODS framework so that we -// can define our operations. -def Dist_Dialect : Dialect { - // The namespace of our dialect - let name = "dist"; - - // A short one-line summary of our dialect. - let summary = "A high-level dialect for distributing NDArray operations"; - - // A longer description of our dialect. - let description = [{ - This dialect provides basic features to allow automatic - partitioning and distribution of NDArrays. The dialect assumes SPMD execution - model. More specifically, each execution unit (or process) executes the same - program but locally owns only a part of the globally distributed data. There is - no central entity which partitions data or assigns work to workers. - - The Dist dialect is related to the NDArray and DistRuntime dialect. It is - expected that the Dist dialect will eventually get lowered to NDArray and - DistRuntime. - }]; - - let dependentDialects = [ - "::imex::ndarray::NDArrayDialect" - ]; - - // The C++ namespace that the dialect class definition resides in. - let cppNamespace = "::imex::dist"; - // let useDefaultTypePrinterParser = true; - let useDefaultAttributePrinterParser = true; - let useDefaultTypePrinterParser = 1; -} - - -class Dist_Attr traits = []> - : AttrDef { - let mnemonic = attrMnemonic; -} - -def DistSliceAttr : Dist_Attr<"DistSlice", "dist_slice"> { - let summary = "List of offsets, sizes and strides for a number of slices."; - let parameters = (ins - ArrayRefParameter<"DenseI64ArrayAttr">:$slcOffsets, - ArrayRefParameter<"DenseI64ArrayAttr">:$slcSizes, - ArrayRefParameter<"DenseI64ArrayAttr">:$slcStrides - // OptionalArrayRefParameter<"::mlir::DenseI64ArrayAttr">:$baseSizes - ); - let assemblyFormat = [{ - `<` `[` $slcOffsets `]` `[` $slcSizes `]` `[` $slcStrides `]` `>` - }]; -} - -def DistEnvAttr - : Dist_Attr<"DistEnv", "dist_env"> { - let summary = "Environment for indicating that a NDArray is distributed"; - let description = [{ - The environment attribute `DistEnvAttr` can be attached to NDArrays. - It carries the information required to describe the partitioning of the global NDArray. - - As an example, let's assume we want to equally distribute `ndarray.ndarray<33xi64>` across a team of 4. The third member of the team would attach label the type as a distributed array by attaching the following `DistEnvAttr`: - - `ndarray.ndarray<44xi64, #dist.dist_env>` - - This defines the following: - - * the array has global size `44` - * the array is distributed across team `37416` - * the local data starts at global index `22` - * the local part is of size `11` - * the size of the right and left halos is `0` - - Notice that `lparts` encodes the shapes of 3 parts that are held locally: - - 1. left halo - 2. locally owned data - 3. right halo - - All parts are of type `ndarray.NDArray`. Halos parts are copies of data owned by - remote team members. Parts always represent pieces of the global array resulting - from block-partitioning, i.e. they represent a contiguous block of the global - index space. Furthermore, the concatenation of left halo, local data and right - halo also represents a contiguous subset of the global index space. - - At this point, arrays are split only in the first dimension. A more general - scheme can be added once required. However, when more than one dimension is cut - it requires more than two halo parts and 'left' and 'right' are no longer - sufficient to describe their position relative to the locally owned data. - - Notice that any part can be empty - even the locally owned part. For example: a - subview of a global array might not intersect with the locally owned part. - - Parts and offsets are omitted (only) for 0d arrays: - `ndarray.ndarray` - - The local offset represents offsets in all dimensions, so in principle allows - partitions across multiple dimensions. For each dimension, the offset is - provided to the first part (in most cases that's the left halo). - - The offsets and sizes in `DistEnvAttr` can be static as in the example. - Alternatively, they can be partially or fully dynamic - even if the global size - is static. The above example with fully dynamic local offsets and sizes would become: - - `ndarray.ndarray<44xi64, #dist.dist_env>` - - There is no placeholder for unknown teams. - - The distribution metadata generalizes to arrays of arbitrary dimensions. Here is - an example of a distributed 2d array type: - - `ndarray.ndarray<44x55xi64, #dist.dist_env>` - - To indicate that the array is distributed across devices/GPUs, an additional - environment gets attached to `ndarray.ndarray`. The additional environment - defines on which device/GPU the local data should be stored. See NDArray spec - for details about GPU support. - - As an example, consider distributing an array across two GPUs in the same computer. The types could look like this - - * team member 0: - `ndarray.ndarray<22xi64, #dist.dist_env, #region.gpu_env>` - * team member 1: - `ndarray.ndarray<22xi64, #dist.dist_env, #region.gpu_env>` - }]; - - let parameters = (ins "::mlir::Attribute":$team, - ArrayRefParameter<"int64_t">:$lOffsets, - "::mlir::SmallVector<::mlir::SmallVector>":$parts_shapes); - - let assemblyFormat = "`<` custom($team, $lOffsets, $parts_shapes) `>`"; - - let builders = [ - AttrBuilderWithInferredContext<(ins "::mlir::Attribute":$team, - "::llvm::ArrayRef":$lOffsets, - "::mlir::SmallVector<::mlir::SmallVector>":$partsShapes)>, - AttrBuilderWithInferredContext<(ins "::mlir::Attribute":$team, "int64_t":$rank)> - ]; - - let extraClassDeclaration = [{ - DistEnvAttr cloneWithDynOffsAndDims() const; - }]; -} - - -// Base class for dialect operations. This operation inherits from the base -// `Op` class in OpBase.td, and provides: -// * The parent dialect of the operation. -// * The mnemonic for the operation, or the name without the dialect prefix. -// * A list of traits for the operation. -class Dist_Op traits = []> : - Op; - -def InitDistArrayOp : Dist_Op<"init_dist_array", [AttrSizedOperandSegments, Pure]> { - let summary = "Instantiate a distributed array, binding to distributed meta information."; - let description = [{ - Accepted dynamic distributed meta information: - - the local offset - - local parts - - The team and resulting global shape is encoded in the result type. - }]; - let arguments = (ins Variadic:$l_offset, Variadic:$parts); - let results = (outs AnyType); - - let assemblyFormat = [{ - oilist(`l_offset` $l_offset | `parts` $parts) attr-dict `:` qualified(type(operands)) `to` qualified(type(results)) - }]; - - let builders = [ - // auto-deduce return type - OpBuilder<(ins "::mlir::Attribute":$team, "::mlir::ArrayRef":$g_shape, "::mlir::ValueRange":$l_offset, - "::mlir::ValueRange":$parts, "::mlir::ArrayRef<::mlir::Attribute>":$environments, "::mlir::ArrayRef":$s_Offs)> -]; -} - -def LocalOffsetsOfOp : Dist_Op<"local_offsets_of", [Pure]> { - let summary = "Get local offsets of a distributed array."; - let description = [{ - Returns `rank`-many values, one for each dimension of `$array`. - }]; - let arguments = (ins AnyType:$array); - let results = (outs Variadic:$l_offsets); - let builders = [ - // autodeduce return type from from operands - OpBuilder<(ins "::mlir::Value":$array), [{ - auto rank = mlir::cast<::imex::ndarray::NDArrayType>(array.getType()).getRank(); - auto IndexType = $_builder.getIndexType(); - ::imex::TypVec rt(rank, IndexType); - build($_builder, $_state, ::mlir::TypeRange(rt), array); - }]>, - ]; -} - -def PartsOfOp : Dist_Op<"parts_of", [Pure]> { - let summary = "Get local parts of a distributed array."; - let description = [{ - Returns either one (0d array) or 3 parts - (all other cases: left halo, locally owned data, right halo) as - `ndarray.ndarray`. Returned arrays have the same rank as the input array. - }]; - let arguments = (ins AnyType:$array); - let results = (outs Variadic:$parts); - let builders = [OpBuilder<(ins "::mlir::Value":$array)>]; - let hasVerifier = 1; -} - -def DefaultPartitionOp : Dist_Op<"default_partition", [SameVariadicResultSize, Pure]> { - let summary = "Compute the default shape and offsets of the local partition."; - let description = [{ - All input and output shapes/offsets are vectors with same length. - - Arrays are cut along the first dimension and partitions are equally distributed - among all members of the team. Member "i" of the team gets assigned to part "i". - Odd elements in the cut dimension are equally distributed among the last team - members. This guarantees that the sizes of local parts differ by at most one - element in the cut dimension. - - For example, an array of size 8 will yield the local part sizes (2, 2, 2, 2) if - the team has 4 members. For a team of 3 it will render (2, 3, 3). - - Other partition strategies could be added later. - }]; - let arguments = (ins Index:$num_procs, Index:$p_rank, Variadic:$g_shape); - let results = (outs Variadic:$l_offsets, Variadic:$l_shape); - let builders = [ - // auto-deduce return type - OpBuilder<(ins "::mlir::Value":$num_procs, "::mlir::Value":$prank, "::mlir::ValueRange":$gshape), [{ - auto IndexType = $_builder.getIndexType(); - ::imex::TypVec rt(gshape.size()*2, IndexType); - build($_builder, $_state, ::mlir::TypeRange(rt), num_procs, prank, gshape); - }]>, - ]; -} - -def LocalTargetOfSliceOp : Dist_Op<"local_target_of_slice", - [SameVariadicOperandSize, SameVariadicResultSize, Pure]> { - let summary = "Compute local intersection of a distributed array with a slice."; - let description = [{ - This operation computes the intersection of the local part of the array and the - provided slice. The slice is provided as a triplet of offsets, sizes and strides - (similar to a subview). While the slice refers to the global index space of the - distributed array, the operation returns local offsets and sizes, relative to - the local part (e.g. these are not global indices). - - All input and output shapes/offsets/strides are `$array.rank()`-long vectors. - }]; - - let arguments = (ins - AnyType:$array, - Variadic:$offsets, - Variadic:$sizes, - Variadic:$strides - ); - let results = (outs Variadic:$t_offsets, Variadic:$t_sizes); - - let assemblyFormat = [{ - $array `[` $offsets `]``[` $sizes `]``[` $strides `]` attr-dict `:` qualified(type($array)) `to` qualified(type(results)) - }]; - - let builders = [ - // auto-deduce return type - OpBuilder<(ins "::mlir::Value":$array, "::mlir::ValueRange":$offsets, "::mlir::ValueRange":$sizes, "::mlir::ValueRange":$strides), [{ - auto IndexType = $_builder.getIndexType(); - ::imex::TypVec rt(offsets.size()*2, IndexType); - build($_builder, $_state, ::mlir::TypeRange(rt), array, offsets, sizes, strides); - }]>, - ]; -} - -def LocalBoundingBoxOp : Dist_Op<"local_bounding_box", [AttrSizedOperandSegments, SameVariadicResultSize, Pure]> { - let summary = "Compute (or extend) bounding box for data locally required by given view and target."; - let description = [{ - The locally required view is the intersection of the given view and target. - - If an existing bounding box is provided, update the bounding box. The update strategy is determined by the `inner` attribute: - - * if `inner` is unset (default) return the convex hull of given bounding box and - locally required view. - * else return the intersection of given bounding box and locally required view. - - If no bounding box is provided (through `b_b_offsets` and `b_b_sizes`) return the offset and shape of the locally required view. - - The bounding box is returned as global offsets and shape. - - All input and output shapes/offsets/strides are vectors with same length. - }]; - - let arguments = (ins I1Attr:$inner, - Variadic:$offsets, Variadic:$sizes, Variadic:$strides, - Variadic:$target_offsets, Variadic:$target_sizes, - Variadic:$b_b_offsets, Variadic:$b_b_sizes); - let results = (outs Variadic:$result_offsets, Variadic:$result_sizes); - - let assemblyFormat = [{ - $inner `[` $offsets `]``[` $sizes `]``[` $strides `]` `[` $target_offsets `]``[` $target_sizes `]` oilist(`bboffs` $b_b_offsets | `bb_sizes` $b_b_sizes) attr-dict `:` qualified(type(results)) - }]; - - let builders = [ - // auto-deduce return type: same as input - OpBuilder<(ins "bool":$inner, "::mlir::ValueRange":$offs, "::mlir::ValueRange":$sizes, "::mlir::ValueRange":$strides, - "::mlir::ValueRange":$toffs, "::mlir::ValueRange":$tsizes, - "::mlir::ValueRange":$bboffs, "::mlir::ValueRange":$bbsizes), [{ - size_t rank = offs.size(); - assert(sizes.size() == rank); - ::imex::TypVec rt(2 * rank, $_builder.getIndexType()); - build($_builder, $_state, ::mlir::TypeRange(rt), inner, offs, sizes, strides, toffs, tsizes, bboffs, bbsizes); - }]>, - ]; -} - -def LocalCoreOp : Dist_Op<"local_core", [AttrSizedOperandSegments, SameVariadicResultSize, Pure]> { - let summary = "Compute or update overlap of given core, locally owned data and locally required data."; - let description = [{ - The locally required view is the intersection of the given slice and target. - - If no local core is provided, return the intersection of locally owned data and - locally required data. Otherwise return the intersection of given core, locally - owned data and locally required data. - - The intersection is returned as global offsets and shape. - }]; - - let arguments = (ins AnyType:$array, - Variadic:$targetOffsets, Variadic:$targetSizes, - Variadic:$sliceOffsets, Variadic:$sliceSizes, Variadic:$sliceStrides, - Variadic:$coreOffsets, Variadic:$coreSizes); - let results = (outs Variadic:$resultOffsets, Variadic:$resultSizes); - - let assemblyFormat = [{ - $array oilist(`toffs` $targetOffsets | `tsizes` $targetSizes | `soffs` $sliceOffsets | `ssizes` $sliceSizes | `sstrides` $sliceStrides | `coffs` $coreOffsets | `csizes` $coreSizes) attr-dict `:` qualified(type($array)) `to` qualified(type(results)) - }]; - - let builders = [ - // auto-deduce return type: same as input - OpBuilder<(ins "::mlir::Value":$array, - "::mlir::ValueRange":$toffs, "::mlir::ValueRange":$tsizes, - "::mlir::ValueRange":$soffs, "::mlir::ValueRange":$ssizes ,"::mlir::ValueRange":$sstrides, - "::mlir::ValueRange":$coffs, "::mlir::ValueRange":$csizes), [{ - size_t rank = mlir::cast<::imex::ndarray::NDArrayType>(array.getType()).getRank(); - ::imex::TypVec rt(rank + rank, $_builder.getIndexType()); - build($_builder, $_state, ::mlir::TypeRange(rt), array, toffs, tsizes, soffs, ssizes, sstrides, coffs, csizes); - }]>, - ]; -} - -def RePartitionOp : Dist_Op<"repartition", [SameVariadicOperandSize, Pure]> { - let summary = "Repartition an array so that each process holds the requested data locally."; - let description = [{ - Creates a new NDArray by repartitioning the input array. It is assumed to be a - collective call. All participating processes request which part of the global - array they need. The halo parts of the returned array get filled with data that - is owned by remote team members. The local data is not modified, the returned - local part is a subview of the local part of the input. - - Target offset and target shape are optional arguments. If missing the operations - returns a default-partitioned array. - }]; - - let arguments = (ins AnyType:$array, - Variadic:$target_offsets, Variadic:$target_sizes); - let results = (outs AnyType); - - let assemblyFormat = [{ - $array oilist(`loffs` $target_offsets | `lsizes` $target_sizes) attr-dict `:` qualified(type(operands)) `to` qualified(type(results)) - }]; - - let builders = [ - // auto-deduce return type: same as input - OpBuilder<(ins "::mlir::Value":$array), [{ - build($_builder, $_state, array.getType(), array, {}, {}); - }]>, - ]; -} - - -// ============================================================================ -// (Extended) operations from NDArray -// ============================================================================ - -def SubviewOp : Dist_Op<"subview", [AttrSizedOperandSegments, Pure]> { - let summary = "Distributed extract slice operation"; - let description = [{ - The distributed subview operation is a shallow wrapper around NDArray.subview. - It extends the ndarray.SubviewOp with optional target offsets and target sizes. - }]; - - let arguments = (ins AnyType:$source, - Variadic:$offsets, - Variadic:$sizes, - Variadic:$strides, - DenseI64ArrayAttr:$static_offsets, - DenseI64ArrayAttr:$static_sizes, - DenseI64ArrayAttr:$static_strides, - Variadic:$target_offsets, - Variadic:$target_sizes - ); - let results = (outs AnyType:$result); - - let assemblyFormat = [{ - $source `` - custom($offsets, $static_offsets) - custom($sizes, $static_sizes) - custom($strides, $static_strides) - oilist(`toffs` $target_offsets | `tsizes` $target_sizes) - attr-dict `:` qualified(type($source)) `to` qualified(type($result)) - }]; -} - -def EWBinOp : Dist_Op<"ewbin", [Pure, SameVariadicOperandSize]> { - let summary = "Distributed elementwise binary operation"; - let description = [{ - The distributed EWBinOp is a shallow wrapper around NDArray.ewbinop. - It extends the ndarray.SubviewOp with optional core offsets, core sizes and target offsets. - }]; - - // ewbin takes 2 NDArrayType operands: lhs and rhs - let arguments = (ins AnyAttr:$op, AnyType:$lhs, AnyType:$rhs, - Variadic:$coreOffsets, Variadic:$coreSizes, - Variadic:$targetOffsets); - // result is a ndarray - let results = (outs AnyType); - let hasVerifier = 1; -} - -def EWUnyOp : Dist_Op<"ewuny", [Pure, SameVariadicOperandSize]> { - let summary = "Distributed elementwise unary operation"; - let description = [{ - The distributed EWUnyOp is a shallow wrapper around NDArray.ewunyop. - It extends the ndarray.EWUnyOp with optional core offsets, core sizes and target offsets. - }]; - - // ewuny takes 1 operand (NDArrayType) and one attribute (unary operation) - let arguments = (ins AnyAttr:$op, AnyType:$src, - Variadic:$coreOffsets, Variadic:$coreSizes, - Variadic:$targetOffsets); - // result is a ndarray - let results = (outs AnyType); -} - -//////////////////////////////////////////////////// -// shard - -// common base classes for types in NDArray dialect -class Dist_Type traits = [], - string baseCppClass = "::mlir::Type"> - : TypeDef { - let mnemonic = typeMnemonic; -} - -def Dist_ShardingConstraint : Dist_Type<"ShardingConstraint", "shardingconstraint"> { - let summary = "sharding constraint"; -} - -def ShardOp : Dist_Op<"shard", [Pure, - AllTypesMatch<["result", "src"]> - ]> { - let summary = "Annotate on how a tensor is sharded across a mesh."; - let description = [{ - The mesh.shard operation is designed to specify and guide the sharding - behavior of a tensor value across a mesh topology. This operation has one - operand and two attributes: - - 1. `input`: This operand represents the tensor value that needs to be - annotated for sharding. - - 2. `shard`: This attribute is type of `MeshSharding`, which is the core data - structure to represent distribution of a tensor on a mesh. - - 3. `annotate_for_users`: A unit attribute addressing the scenario when a - tensor's sharding annotation differs based on its context of use (either as - a result or an operand). If specified, the sharding pertains to specific - users of the tensor value, indicating how it should be considered when used - as an operand in subsequent operations. If not, the sharding applies to the - operation that defines the tensor value. - - Example: - ``` - func.func @only_result_annotated(%arg0 : tensor<4x8xf32>) -> () { - %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32> - ... - } - - func.func @only_operand_annotated(%arg0 : tensor<4x8xf32>) -> () { - %0 = mesh.shard %arg0 to <@mesh0, [[0]]> annotate_for_users : tensor<4x8xf32> - ... - } - - // The first mesh.shard op applies to %arg0, the second mesh.shard op - // applies for the operand of op0, the third mesh.shard op applies for the - // operand of op2 - func.func @both_result_and_multi_operands_annotated( - %arg0 : tensor<4x8xf32>) -> () { - %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32> - %1 = mesh.shard %0 to <@mesh0, [[1]]> annotate_for_users : tensor<4x8xf32> - %2 = mesh.shard %0 to <@mesh0, [[2]]> annotate_for_users : tensor<4x8xf32> - "op0"(%1) : ... - "op1"(%2) : ... - ... - } - ``` - - The following usages are undefined: - ``` - func.func @annotate_on_same_result_with_different_sharding( - %arg0 : tensor<4x8xf32>) -> () { - %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32> - %1 = mesh.shard %0 to <@mesh0, [[1]]> : tensor<4x8xf32> - ... - } - - func.func @annotate_on_same_result_same_value_with_different_sharding( - %arg0 : tensor<4x8xf32>) -> () { - %0 = mesh.shard %arg0 to <@mesh0, [[0]]> : tensor<4x8xf32> - %1 = mesh.shard %arg0 to <@mesh0, [[1]]> : tensor<4x8xf32> - ... - } - - func.func @annotate_on_same_operand_with_different_sharding( - %arg0 : tensor<4x8xf32>) -> () { - %0 = mesh.shard %arg0 to <@mesh0, [[0]]> annotate_for_users : tensor<4x8xf32> - %1 = mesh.shard %0 to <@mesh0, [[1]]> annotate_for_users : tensor<4x8xf32> - ... - } - - func.func @result_annotated_after_operand( - %arg0 : tensor<4x8xf32>) -> () { - %0 = mesh.shard %arg0 to <@mesh0, [[0]]> annotate_for_users : tensor<4x8xf32> - %1 = mesh.shard %0 to <@mesh0, [[1]]> : tensor<4x8xf32> - ... - } - ``` - }]; - let arguments = (ins - AnyRankedTensor:$src, - AnyAttr:$shard, - UnitAttr:$annotate_for_users, - Optional:$constraint - ); - let results = (outs - AnyRankedTensor:$result - ); - let assemblyFormat = [{ - $src `to` ($constraint^)? attr-dict `:` type($result) - }]; -} - -def TargetOfSliceOp : Dist_Op<"target_of_slice", - [Pure]> { - let summary = "Compute intersections of a distributed array with a slice."; - let description = [{ - This operation computes the intersection of the shards of the array and the - provided slice. The slice is provided as a triplet of offsets, sizes and strides - (similar to a subview). - - All input shapes/offsets/strides are `$array.rank()`-long vectors. - The operation produces $array.rank()*split_axes.size() many values. - }]; - - let arguments = (ins - AnyType:$array, - DenseI64ArrayAttr:$static_offsets, - DenseI64ArrayAttr:$static_sizes, - DenseI64ArrayAttr:$static_strides, - FlatSymbolRefAttr:$mesh, - Mesh_MeshAxesArrayAttr:$split_axes - ); - let results = (outs Variadic:$result); - - let assemblyFormat = [{ - $array attr-dict `:` qualified(type($array)) `to` qualified(type(results)) - }]; -} - -def BoundingBoxOp : Dist_Op<"bounding_box", [SameVariadicOperandSize, Pure]> { - let summary = "Compute bounding box."; - - let arguments = (ins - DistSliceAttr:$static_slices, - Variadic:$subviewConstraints); - let results = (outs Dist_ShardingConstraint:$constraint); - - let assemblyFormat = [{ - `[` $subviewConstraints `]` attr-dict `:` qualified(type(results)) - }]; -} - -def ExtendHaloForSliceOp : Dist_Op<"extend_halo_for_slice", [Pure]> { - let summary = "Extend halo for slice."; - let description = [{ - This operation computes the halo required to compute the slice of the array - And extends provided halo if necessary. - The halo is returned as a flattened array. - }]; - - let arguments = (ins - DenseI64ArrayAttr:$static_shape, - FlatSymbolRefAttr:$mesh, - Mesh_MeshAxesArrayAttr:$split_axes, - Variadic:$halo_sizes, - DenseI64ArrayAttr:$static_offsets, - DenseI64ArrayAttr:$static_sizes, - DenseI64ArrayAttr:$static_strides, - DenseI64ArrayAttr:$sharded_dims_offsets - ); - let results = (outs Variadic:$result); - - let assemblyFormat = [{ - `[` $halo_sizes `]` attr-dict `:` qualified(type(results)) - }]; -} - - -#endif // _Dist_OPS_TD_INCLUDED_ diff --git a/include/imex/Dialect/Dist/Transforms/CMakeLists.txt b/include/imex/Dialect/Dist/Transforms/CMakeLists.txt deleted file mode 100644 index 2e11fd05a..000000000 --- a/include/imex/Dialect/Dist/Transforms/CMakeLists.txt +++ /dev/null @@ -1,7 +0,0 @@ -set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls -name Dist) -mlir_tablegen(Passes.capi.h.inc -gen-pass-capi-header --prefix Dist) -mlir_tablegen(Passes.capi.cpp.inc -gen-pass-capi-impl --prefix Dist) -add_public_tablegen_target(IMEXDistPassIncGen) - -add_mlir_doc(Passes DistPasses ./ -gen-pass-doc) diff --git a/include/imex/Dialect/Dist/Transforms/Passes.h b/include/imex/Dialect/Dist/Transforms/Passes.h deleted file mode 100644 index 253acf61a..000000000 --- a/include/imex/Dialect/Dist/Transforms/Passes.h +++ /dev/null @@ -1,53 +0,0 @@ -//===-- Passes.h - Dist pass declaration file -------------------*- C++ -*-===// -// -// Copyright 2023 Intel Corporation -// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -/// -/// \file -/// This header file defines prototypes that expose pass constructors for the -/// Dist dialect. -/// -//===----------------------------------------------------------------------===// - -#ifndef _Dist_PASSES_H_INCLUDED_ -#define _Dist_PASSES_H_INCLUDED_ - -#include - -namespace mlir { -class LLVMTypeConverter; -class MLIRContext; -class ModuleOp; -template class OperationPass; -class RewritePatternSet; -} // namespace mlir - -namespace imex { - -//===----------------------------------------------------------------------===// -/// Dist passes. -//===----------------------------------------------------------------------===// - -/// Create a DistCoalesce pass -std::unique_ptr<::mlir::Pass> createDistCoalescePass(); -/// Create DistInferEWBinopPass -std::unique_ptr<::mlir::Pass> createDistInferEWCoresPass(); - -#define GEN_PASS_DECL -#include - -//===----------------------------------------------------------------------===// -// Registration -//===----------------------------------------------------------------------===// - -/// Generate the code for registering passes. -#define GEN_PASS_REGISTRATION -#include - -} // namespace imex - -#endif // _Dist_PASSES_H_INCLUDED_ diff --git a/include/imex/Dialect/Dist/Transforms/Passes.td b/include/imex/Dialect/Dist/Transforms/Passes.td deleted file mode 100644 index 9da974712..000000000 --- a/include/imex/Dialect/Dist/Transforms/Passes.td +++ /dev/null @@ -1,53 +0,0 @@ -//===-- Passes.td - Dist pass definition file --------------*- tablegen -*-===// -// -// Copyright 2023 Intel Corporation -// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -/// -/// \file -/// This file defines passes/transformations of the Dist dialect. -/// -//===----------------------------------------------------------------------===// - -#ifndef _Dist_PASSES_TD_INCLUDED_ -#define _Dist_PASSES_TD_INCLUDED_ - -include "mlir/Pass/PassBase.td" - -//===----------------------------------------------------------------------===// -// DistCoalesce -//===----------------------------------------------------------------------===// - -def DistCoalesce : Pass<"dist-coalesce", "::mlir::func::FuncOp"> { - let summary = "Coalesce operations from Dist dialect."; - let description = [{ - TODO - }]; - let constructor = "imex::createDistCoalescePass()"; - let dependentDialects = ["::imex::dist::DistDialect", - "::imex::distruntime::DistRuntimeDialect", - "::mlir::arith::ArithDialect", - "::mlir::tensor::TensorDialect", - "::mlir::memref::MemRefDialect"]; - let options = [ - Option<"in_jit", "in-jit", "bool", /*default=*/"true", - "Assume (or not) that pass is run within a jit.">, - ]; -} - -def DistInferEWCores : Pass<"dist-infer-elementwise-cores", "::mlir::func::FuncOp"> { - let summary = "Add core for dependent elementwise operations."; - let description = [{ - Distributed tensors can have non-contiguous data. Elementwise operations on - shifted views therefore lead to multiple loops with different shapes which prevents - loop fusion. This pass tries to compute the intersection of loop boundaries for a series of - dependent elementwise operations and adds this information to the respective ops. - }]; - let constructor = "imex::createDistInferEWCoresPass()"; - let dependentDialects = ["::imex::dist::DistDialect"]; -} - -#endif // _Dist_PASSES_TD_INCLUDED_ diff --git a/include/imex/Dialect/Dist/Utils/CMakeLists.txt b/include/imex/Dialect/Dist/Utils/CMakeLists.txt deleted file mode 100644 index d941185cf..000000000 --- a/include/imex/Dialect/Dist/Utils/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -install(FILES Utils.h DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/imex/Dialect/Dist/Utils) diff --git a/include/imex/Dialect/Dist/Utils/Utils.h b/include/imex/Dialect/Dist/Utils/Utils.h deleted file mode 100644 index 05321de0e..000000000 --- a/include/imex/Dialect/Dist/Utils/Utils.h +++ /dev/null @@ -1,337 +0,0 @@ -//===- Utils.h - Utils for Dist dialect -----------------------*- C++ -*-===// -// -// Copyright 2023 Intel Corporation -// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -/// -/// \file -/// This file declares the utils for the dist dialect. -/// -//===----------------------------------------------------------------------===// - -#ifndef _DIST_UTILS_H_INCLUDED_ -#define _DIST_UTILS_H_INCLUDED_ - -#include -#include -#include -#include -#include - -#include -#include - -namespace imex { -namespace dist { - -// ******************************* -// ***** Some helper functions *** -// ******************************* - -/// @return true if atribute is a DistEnvAttr -inline bool isDist(const ::mlir::Attribute &a) { - return ::mlir::isa<::imex::dist::DistEnvAttr>(a); -} - -/// @return true if type has a DistEnvAttr -inline bool isDist(const ::imex::ndarray::NDArrayType &t) { - return ::imex::ndarray::hasEnv<::imex::dist::DistEnvAttr>(t); -} - -/// @return true if type has a DistEnvAttr -inline bool isDist(const ::mlir::Type &t) { - auto arType = mlir::dyn_cast<::imex::ndarray::NDArrayType>(t); - return arType ? isDist(arType) : false; -} - -/// @return true if value is a DistEnvAttr -inline bool isDist(const ::mlir::Value &v) { - auto arType = mlir::dyn_cast<::imex::ndarray::NDArrayType>(v.getType()); - return arType ? isDist(arType) : false; -} - -/// @return first DistEnvAttr, null-attr if none exists -inline ::imex::dist::DistEnvAttr -getDistEnv(const ::imex::ndarray::NDArrayType &t) { - for (auto a : t.getEnvironments()) { - if (auto d = ::mlir::dyn_cast<::imex::dist::DistEnvAttr>(a)) { - return d; - } - } - return {}; -} - -/// @return return NDArray's env attributes except DistEnvAttrs -inline ::mlir::SmallVector<::mlir::Attribute> -getNonDistEnvs(const ::imex::ndarray::NDArrayType &t) { - ::mlir::SmallVector<::mlir::Attribute> envs; - std::copy_if(t.getEnvironments().begin(), t.getEnvironments().end(), - std::back_inserter(envs), [](auto i) { return !isDist(i); }); - return envs; -} - -/// @return clone of type, but with dynamic shapes (local and global) -inline ::imex::ndarray::NDArrayType -cloneWithDynEnv(const ::imex::ndarray::NDArrayType &ary) { - auto oEnvs = ary.getEnvironments(); - ::mlir::SmallVector<::mlir::Attribute> envs; - for (auto e : oEnvs) { - if (auto a = ::mlir::dyn_cast<::imex::dist::DistEnvAttr>(e)) { - e = a.cloneWithDynOffsAndDims(); - } - envs.emplace_back(e); - } - return ::imex::ndarray::NDArrayType::get(ary.getShape(), ary.getElementType(), - envs); -} - -/// @return clone of type, but with dynamic shapes (local and global) -inline ::imex::ndarray::NDArrayType -cloneWithShape(const ::imex::ndarray::NDArrayType &ary, - const ::mlir::ArrayRef shape) { - auto oEnvs = ary.getEnvironments(); - ::mlir::SmallVector<::mlir::Attribute> envs; - for (auto e : oEnvs) { - if (auto a = ::mlir::dyn_cast<::imex::dist::DistEnvAttr>(e)) { - e = ::imex::dist::DistEnvAttr::get(a.getTeam(), shape.size()); - } - envs.emplace_back(e); - } - return ::imex::ndarray::NDArrayType::get(shape, ary.getElementType(), envs); -} - -/// @return clone of type, but without dist-env and with dynamic shapes -inline ::imex::ndarray::NDArrayType -cloneAsDynNonDist(const ::imex::ndarray::NDArrayType &ary) { - auto envs = getNonDistEnvs(ary); - if (ary.hasUnitSize()) { - return ::imex::ndarray::NDArrayType::get(ary.getShape(), - ary.getElementType(), envs); - } else { - return ::imex::ndarray::NDArrayType::get( - ::mlir::SmallVector(ary.getRank(), - ::mlir::ShapedType::kDynamic), - ary.getElementType(), envs); - } -} - -/// @return clone of type, but without dist-env -inline ::imex::ndarray::NDArrayType -cloneAsNonDist(const ::imex::ndarray::NDArrayType &ary) { - auto envs = getNonDistEnvs(ary); - return ::imex::ndarray::NDArrayType::get(ary.getShape(), ary.getElementType(), - envs); -} - -/// @return types of NDArray's parts if it has distenv, empty vector otherwise -inline ::imex::TypVec -getPartsTypes(const ::imex::ndarray::NDArrayType &arType) { - ::imex::TypVec res; - if (auto dEnv = getDistEnv(arType)) { - auto shapes = dEnv.getPartsShapes(); - auto envs = getNonDistEnvs(arType); - if (arType.getRank() == 0 && shapes.size() == 0) { - res.emplace_back(::imex::ndarray::NDArrayType::get( - {1}, arType.getElementType(), envs)); // FIXME env layout - } else { - for (auto shp : shapes) { - res.emplace_back(::imex::ndarray::NDArrayType::get( - shp, arType.getElementType(), envs)); // FIXME env layout - } - } - } - return res; -} - -/// Create a distributed array from a NDArray and meta data -inline ::mlir::Value -createDistArray(const ::mlir::Location &loc, ::mlir::OpBuilder &builder, - ::mlir::Attribute team, ::mlir::ArrayRef gshape, - ::mlir::ValueRange loffs, ::mlir::ValueRange parts, - ::mlir::ArrayRef sOffs = {}) { - assert(parts.size() == 1 || parts.size() == 3); - ::imex::ValVec nParts; - auto p = mlir::cast<::imex::ndarray::NDArrayType>(parts.front().getType()); - auto envs = p.getEnvironments(); - auto rank = p.getRank(); - assert(rank || parts.size() == 1); - - if (parts.size() == 1 && rank) { - auto elType = p.getElementType(); - ::imex::ValVec shp(rank, createIndex(loc, builder, 0)); - auto lHalo = builder.create<::imex::ndarray::CreateOp>( - loc, shp, ::imex::ndarray::fromMLIR(elType), nullptr, envs); - auto rHalo = builder.create<::imex::ndarray::CreateOp>( - loc, shp, ::imex::ndarray::fromMLIR(elType), nullptr, envs); - nParts = {lHalo, parts.front(), rHalo}; - } else { - nParts = parts; - assert(parts.size() == 3 || rank == 0); - } - - for (auto x : parts) - assert(!isDist(x)); - - return builder.create<::imex::dist::InitDistArrayOp>(loc, team, gshape, loffs, - nParts, envs, sOffs); -} - -inline ::mlir::Value -createDistArray(const ::mlir::Location &loc, ::mlir::OpBuilder &builder, - ::mlir::Attribute team, ::imex::ValVec gshape, - ::mlir::ValueRange loffs, ::mlir::ValueRange parts) { - auto gshp = mkConstant(gshape); - return createDistArray(loc, builder, team, gshp, loffs, parts); -} - -// create operation returning global shape of distributed array -inline ::imex::ValVec createGlobalShapeOf(const ::mlir::Location &loc, - ::mlir::OpBuilder &builder, - ::mlir::Value ary) { - auto gshp = - mlir::cast<::imex::ndarray::NDArrayType>(ary.getType()).getShape(); - ::imex::ValVec res; - for (auto d : gshp) { - res.emplace_back(createIndex(loc, builder, d)); - } - return res; -} - -// create operation returning local offsets of distributed array -inline ::mlir::ValueRange createLocalOffsetsOf(const ::mlir::Location &loc, - ::mlir::OpBuilder &builder, - ::mlir::Value ary) { - return builder.create<::imex::dist::LocalOffsetsOfOp>(loc, ary).getLOffsets(); -} - -// create operation returning all parts (owned + halos) of distributed array -inline ::mlir::ValueRange createPartsOf(const ::mlir::Location &loc, - ::mlir::OpBuilder &builder, - ::mlir::Value ary) { - return builder.create<::imex::dist::PartsOfOp>(loc, ary).getParts(); -} - -inline ::mlir::Value createNProcs(const ::mlir::Location &loc, - ::mlir::OpBuilder &builder, - ::mlir::Attribute team) { - return builder.createOrFold<::imex::distruntime::TeamSizeOp>(loc, team); -} - -inline ::mlir::Value createPRank(const ::mlir::Location &loc, - ::mlir::OpBuilder &builder, - ::mlir::Attribute team) { - return builder.createOrFold<::imex::distruntime::TeamMemberOp>(loc, team); -} - -// create operation returning the re-partitioned array -inline ::mlir::Value createRePartition(const ::mlir::Location &loc, - ::mlir::OpBuilder &builder, - ::mlir::Value ary, - const ::mlir::ValueRange &tOffs = {}, - const ::mlir::ValueRange &tSzs = {}) { - auto retTyp = mlir::cast<::imex::ndarray::NDArrayType>(ary.getType()) - .cloneWithDynDims(); - return builder.create<::imex::dist::RePartitionOp>(loc, retTyp, ary, tOffs, - tSzs); -} - -inline auto createDefaultPartition(const ::mlir::Location &loc, - ::mlir::OpBuilder &builder, - ::mlir::Attribute team, - ::imex::ValVec gShape) { - auto nProcs = createNProcs(loc, builder, team); - auto pRank = createPRank(loc, builder, team); - return builder.create<::imex::dist::DefaultPartitionOp>(loc, nProcs, pRank, - gShape); -} - -template static T _min(const T &a, const T &b) { - return std::min(a, b); -} -template static T _max(const T &a, const T &b) { - return std::max(a, b); -} -template static T _get(const T &a) { return a; }; -template struct _gen { - template - T operator()(::mlir::Location, ::mlir::OpBuilder, const U &a) { - return static_cast(a); - } -}; - -[[maybe_unused]] static EasyIdx _min(const EasyIdx &a, const EasyIdx &b) { - return a.min(b); -} -[[maybe_unused]] static EasyIdx _max(const EasyIdx &a, const EasyIdx &b) { - return a.max(b); -} -[[maybe_unused]] static EasyIdx::ElType _get(const EasyIdx &a) { - return a.get(); -}; -template <> struct _gen { - template - EasyIdx operator()(::mlir::Location loc, ::mlir::OpBuilder rewriter, - const U &a) { - return easyIdx(loc, rewriter, a); - } -}; - -/// @brief compute overlap of given slices with local off/shape -/// @param lShape local shape -/// @param lOff local offset -/// @param slcOff slice's offset -/// @param slcSize slice's size -/// @param slcStride slice's stride -/// @return offsets and sizes of overlap and leading/skipped elements of slice -template -inline std::tuple -createOverlap(::mlir::Location loc, ::mlir::OpBuilder rewriter, - const ::mlir::ValueRange &lOffs, const ::mlir::ValueRange &lShape, - const I &slcOffs, const I &slcSizes, const I &slcStrides, - size_t rank = 0) { - rank = rank ? rank : lShape.size(); - auto mygen = _gen(); - auto zero = mygen(loc, rewriter, 0); - auto one = mygen(loc, rewriter, 1); - - R resOffs(rank, _get(zero)); - R resSlcOffs(rank, _get(zero)); - R resSizes(slcSizes.begin(), slcSizes.end()); - - for (unsigned i = 0; i < rank; ++i) { - // Get the vals from dim - auto lOff = mygen(loc, rewriter, lOffs[i]); - auto slcOff = mygen(loc, rewriter, slcOffs[i]); - auto slcStride = mygen(loc, rewriter, slcStrides[i]); - auto slcSize = mygen(loc, rewriter, slcSizes[i]); - auto lSize = mygen(loc, rewriter, lShape[i]); - - // last index of slice - auto slcEnd = slcOff + slcSize * slcStride; - // last index of local partition - auto lEnd = lOff + lSize; - - auto maxOff = lOff.max(slcOff); - auto stride_1 = slcStride - one; - // slc { class DistRuntime_Op traits = []> : Op; -def TeamSizeOp : DistRuntime_Op<"team_size", [Pure]> { - let summary = "Get number of members in given team"; - let description = [{ - Operation that returns the number of team members of a given team. - }]; - let arguments = (ins AnyAttr:$team); - let results = (outs Index); - let builders = [ - OpBuilder<(ins "::mlir::Attribute":$team), [{ - build($_builder, $_state, $_builder.getIndexType(), team); - }]>, - ]; - // to be defined by the lowerer - let hasFolder = 1; -} - -def TeamMemberOp : DistRuntime_Op<"team_member", [Pure]> { - let summary = "Get member of given team"; - let description = [{ - Operation that returns the member of the given team that the caller - represents. Members of a team are represented as Index types. - }]; - let arguments = (ins AnyAttr:$team); - let results = (outs Index); - let builders = [ - OpBuilder<(ins "::mlir::Attribute":$team), [{ - build($_builder, $_state, $_builder.getIndexType(), team); - }]>, - ]; - // to be defined by the lowerer - let hasFolder = 1; -} - -def AllReduceOp : DistRuntime_Op<"allreduce", []> { - let summary = "Inplace allreduce"; - let description = [{ - Operation that performs an in-place all-reduce with a given operation. - The shape of the data argument must be identical for all members of the team. - Reduction happens for each element of the argument over all team members. - The meaning of the 'op' attribute is defined by the lowering passes. - }]; - // reduction operation and local tensor - let arguments = (ins AnyAttr:$op, AnyMemRef:$data); -} - -def GetHaloOp : DistRuntime_Op<"get_halo", - [SameVariadicOperandSize, DeclareOpInterfaceMethods]> { - let summary = "Get left and right halos"; - let description = [{ - For a given, distributed array, compute and return the left and right - halo as implied by the locally owned data and requested bounding box. - - Data that is not locally owned will be provided in the left or right - halo, depending on if the data is from before or after the local part - in the first dimension of the global array. Hence it is possible that - one or both returned halos are empty. - - The local data is not modified. - - Arguments: - - - `local`: the locally owned data - - `gShape`: the global shape of the distributed array - - `lOffsets`: the offset of the local data within the global array - - `bbOffsets`: the offsets of the requested data part - - `bbSizes`: the shape of the requested data part - - `team`: the distributed team owning the distributed array - - `key` [optional]: a statically assigned id for the given operation (to allow caching) - - `gShape`, `lOffsets`, `bbOffsets` and `bbSizes` are variadic arguments - with same size `r` where `r` is the rank of the global array (e.g., one - number for each dimension of the global array). - - Returns an `AsyncHandle`, the left and the right halo. - }]; - let arguments = (ins AnyType:$local, Variadic:$gShape, Variadic:$lOffsets, - Variadic:$bbOffsets, Variadic:$bbSizes, - AnyAttr:$team, DefaultValuedAttr:$key); - let results = (outs DistRuntime_AsyncHandle:$handle, AnyType:$lHalo, AnyType:$rHalo); - - let builders = [ - // auto-deduce return type - OpBuilder<(ins "::mlir::Value":$local, "::mlir::ValueRange":$gShape, "::mlir::ValueRange":$lOffsets, - "::mlir::ValueRange":$bbOffsets, "::mlir::ValueRange":$bbSizes, - "::mlir::ValueRange":$lHSizes, "::mlir::ValueRange":$rHSizes, - "::mlir::Attribute":$team, CArg<"int64_t", "-1L">:$key)> - ]; - let hasCanonicalizer = 1; -} - def CopyReshapeOp : DistRuntime_Op<"copy_reshape", [Pure, DeclareOpInterfaceMethods, AttrSizedOperandSegments]> { let summary = "Copy adequate data from input to a new reshaped output"; diff --git a/include/imex/Dialect/DistRuntime/Transforms/Passes.h b/include/imex/Dialect/DistRuntime/Transforms/Passes.h index 0b0c0eccd..5d2abc634 100644 --- a/include/imex/Dialect/DistRuntime/Transforms/Passes.h +++ b/include/imex/Dialect/DistRuntime/Transforms/Passes.h @@ -33,11 +33,7 @@ namespace imex { //===----------------------------------------------------------------------===// std::unique_ptr<::mlir::Pass> createDistRuntimeToIDTRPass(); -std::unique_ptr<::mlir::Pass> createOverlapCommAndComputePass(); -std::unique_ptr<::mlir::Pass> createAddCommCacheKeysPass(); -#define GEN_PASS_DECL_OVERLAPCOMMANDCOMPUTE -#define GEN_PASS_DECL_ADDCOMMCACHEKEYS #define GEN_PASS_DECL_LOWERDISTRUNTIMETOIDTR #include diff --git a/include/imex/Dialect/DistRuntime/Transforms/Passes.td b/include/imex/Dialect/DistRuntime/Transforms/Passes.td index 229b1e505..654c09605 100644 --- a/include/imex/Dialect/DistRuntime/Transforms/Passes.td +++ b/include/imex/Dialect/DistRuntime/Transforms/Passes.td @@ -31,19 +31,4 @@ def LowerDistRuntimeToIDTR: Pass<"lower-distruntime-to-idtr"> { let options = []; } -def OverlapCommAndCompute : Pass<"overlap-comm-and-compute"> { - let summary = "Try to make asynchronous communication overlap some computation."; - let constructor = "imex::createOverlapCommAndComputePass()"; - let dependentDialects = ["::imex::ndarray::NDArrayDialect", - "::imex::distruntime::DistRuntimeDialect"]; - let options = []; -} - -def AddCommCacheKeys : Pass<"add-comm-cache-keys"> { - let summary = "Add unique keys to each distruntime.udpate_halo op."; - let constructor = "imex::createAddCommCacheKeysPass()"; - let dependentDialects = []; - let options = []; -} - #endif // _DistRuntime_PASSES_TD_INCLUDED_ diff --git a/include/imex/Dialect/NDArray/IR/CMakeLists.txt b/include/imex/Dialect/NDArray/IR/CMakeLists.txt index e942c4d90..6cdedf49d 100644 --- a/include/imex/Dialect/NDArray/IR/CMakeLists.txt +++ b/include/imex/Dialect/NDArray/IR/CMakeLists.txt @@ -1,2 +1,8 @@ add_mlir_dialect(NDArrayOps ndarray) add_mlir_doc(NDArrayOps NDArrayDialect Dialects/ -gen-dialect-doc) + +set(LLVM_TARGET_DEFINITIONS NDArrayOps.td) +mlir_tablegen(NDArrayOpsAttrs.h.inc -gen-attrdef-decls -attrdefs-dialect=ndarray) +mlir_tablegen(NDArrayOpsAttrs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=ndarray) +add_public_tablegen_target(MLIRNDArrayIncGen) +add_dependencies(mlir-headers MLIRNDArrayIncGen) \ No newline at end of file diff --git a/include/imex/Dialect/NDArray/IR/NDArrayOps.h b/include/imex/Dialect/NDArray/IR/NDArrayOps.h index 7d971f36a..f4c578059 100644 --- a/include/imex/Dialect/NDArray/IR/NDArrayOps.h +++ b/include/imex/Dialect/NDArray/IR/NDArrayOps.h @@ -71,6 +71,8 @@ class NDArrayBase : public mlir::Type, #include #define GET_TYPEDEF_CLASSES #include +#define GET_ATTRDEF_CLASSES +#include #define GET_OP_CLASSES #include @@ -80,43 +82,45 @@ namespace imex { namespace ndarray { /// @return true if given NDArrayTYpe has this specific environment attribute -template bool hasEnv(const ::imex::ndarray::NDArrayType &t) { - for (auto a : t.getEnvironments()) { - if (::mlir::isa(a)) { - return true; +template bool hasEnv(const ::mlir::RankedTensorType &t) { + auto encoding = t.getEncoding(); + if (auto envs = ::mlir::dyn_cast(encoding)) { + for (auto a : envs.getEnvs()) { + if (::mlir::isa(a)) { + return true; + } } } return false; } inline bool hasGPUEnv(const ::mlir::Type &t) { - auto ptType = mlir::dyn_cast<::imex::ndarray::NDArrayType>(t); + auto ptType = mlir::dyn_cast<::mlir::RankedTensorType>(t); return ptType ? ::imex::ndarray::hasEnv<::imex::region::GPUEnvAttr>(ptType) : false; } inline ::imex::region::GPUEnvAttr getGPUEnv(const ::mlir::Type &t) { - auto ptType = mlir::dyn_cast<::imex::ndarray::NDArrayType>(t); - if (ptType) { - for (auto a : ptType.getEnvironments()) { - if (auto g = ::mlir::dyn_cast<::imex::region::GPUEnvAttr>(a)) { - return g; + if (auto tt = ::mlir::dyn_cast<::mlir::RankedTensorType>(t)) { + auto encoding = tt.getEncoding(); + if (auto envs = ::mlir::dyn_cast(encoding)) { + for (auto a : envs.getEnvs()) { + if (auto g = ::mlir::dyn_cast<::imex::region::GPUEnvAttr>(a)) { + return g; + } } } } return {}; } -// Determine whether CastOp casts to a nore dynamic version of the source tensor -bool canFoldIntoConsumerOp(CastOp castOp); -bool canFoldIntoConsumerOp(::mlir::tensor::CastOp castOp); - /// Performs folding of any operand of `op` if it comes from a ndarray::CastOp /// that can be folded. mlir::LogicalResult foldArrayCast(mlir::Operation *op); /// @return true if shape is known to span exactly one element bool isUnitShape(const llvm::ArrayRef shp); +bool hasZeroSize(const llvm::ArrayRef shp); } // namespace ndarray } // namespace imex diff --git a/include/imex/Dialect/NDArray/IR/NDArrayOps.td b/include/imex/Dialect/NDArray/IR/NDArrayOps.td index ebefc172e..284387f96 100644 --- a/include/imex/Dialect/NDArray/IR/NDArrayOps.td +++ b/include/imex/Dialect/NDArray/IR/NDArrayOps.td @@ -23,133 +23,47 @@ include "mlir/Interfaces/ViewLikeInterface.td" include "mlir/IR/BuiltinTypeInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/ViewLikeInterface.td" -include "mlir/Interfaces/ShapedOpInterfaces.td" -include "mlir/Interfaces/DestinationStyleOpInterface.td" +// include "mlir/Interfaces/ShapedOpInterfaces.td" +// include "mlir/Interfaces/DestinationStyleOpInterface.td" include "mlir/IR/OpAsmInterface.td" include "mlir/Interfaces/CastInterfaces.td" -// Provide a definition of the 'NDArray' so that we can define our operations. def NDArray_Dialect : Dialect { let name = "ndarray"; let cppNamespace = "::imex::ndarray"; let summary = "A high-level dialect for parallel tensor operations"; let description = [{ - The ndarray dialect describes parallel operations on tensors. - Generic parallel patterns are provided, such as element-wise-unary, - element-wise-binary or reduce. + The ndarray dialect describes high-lvel operations on arrays. + It extends the tenssor and linalg dialects with operations which + have array specific semantics, like mutating operations. - Generally the NDArray dialect is intended to provide high-level abstractions - to allow compute-follows-data semantics. For this the NDArrayType constitutes - a ranked tensor with information about the location (device, team) of - the tensor-data when NDArrays are created. - - The functional scope of the dialect is the - [array-API](https://data-apis.org/array-api/latest/index.html). - - The NDArray differs from tensor dialects in MLIR because it + The dialects differs from tensor dialects in MLIR because it it is meant to allow operations with in-place semantics and creating subviews which are guaranteed to be views. - }]; - - // We use the default parser/printer which handles registered types - let useDefaultTypePrinterParser = true; - let hasConstantMaterializer = 1; -} - - -// common base classes for types in NDArray dialect -class NDArray_Type traits = [], - string baseCppClass = "::mlir::Type"> - : TypeDef { - let mnemonic = typeMnemonic; -} - - -def NDArray_NDArray : NDArray_Type<"NDArray", "ndarray", [ShapedTypeInterface], - "::imex::ndarray::NDArrayBase"> { - let summary = "Multi-dimensional numpy-like array"; - let description = [{ - Multi-dimensional numpy-like array. - - Contrary to upstream tensor type is has a reference semantics and allow to - modify data inplace. - - The NDArray has a dynamic shape and an element-type. - - Additionally it has optional `environment` attribute, which specifies additional environment - information for computations on this tensor. One or more environments define - the location of the data and therefore where operations are expected to be - executed. For example, a GPUEnvAttr can be attached to indicate the array is - expected to be allocated on a (sepcific) GPU. Similarly, a DistEnvAttr would - annotate the array to be distributed. - - Examples: - A 6x6 array of 32bit ints on the first GPU device through the OpenCL backend could look like this: - `!ndarray.ndarray<6x6xi32, #region.gpu_env>` - - A distributed 6x6 array of 32bit ints, with team `22` and which locally owns the last 3 rows: - `!ndarray.ndarray<6x6xi32, #dist.dist_env>` - - Combining the above two yields a distributed array where the local part is assigned to a GPU: - `!ndarray.ndarray<6x6xi32, #dist.dist_env, #region.gpu_env>` - }]; - - let parameters = (ins - ArrayRefParameter<"int64_t">:$shape, // array shape, allows dynamic dims - "::mlir::Type":$elementType, // element type - OptionalArrayRefParameter<"::mlir::Attribute">:$environments, // environments - OptionalParameter<"::mlir::StringAttr">:$layout // layout - ); - let assemblyFormat = [{ - `<` custom($shape, $elementType) (`:` $layout^)? (`,` $environments^)? `>` - }]; + Generally the ndarray dialect is intended to provide high-level + abstractions to allow compute-follows-data semantics. For this, + the dialect operates on ranked tensors and attaches information + about the location (device, team) of the tensor-data when + arrays are created. These annotations are done through the + mesh dialect. - let skipDefaultBuilders = 1; - let builders = [ - // inferred type - TypeBuilder<(ins - "::llvm::ArrayRef":$shape, - "::mlir::Type":$elementType, - "::mlir::ArrayRef<::mlir::Attribute>":$environments, - "::mlir::StringAttr":$layout)>, - // inferred type, optional environments and layout - TypeBuilderWithInferredContext<(ins - "::llvm::ArrayRef":$shape, - "::mlir::Type":$elementType, - CArg<"::mlir::ArrayRef<::mlir::Attribute>", "{}">:$environments, - CArg<"std::optional<::llvm::StringRef>", "std::nullopt">:$layout - )>, - // inferred type, no context - TypeBuilderWithInferredContext<(ins - "::llvm::ArrayRef":$shape, - "::mlir::Type":$elementType, - "::mlir::ArrayRef<::mlir::Attribute>":$environments, - "::mlir::StringAttr":$layout - )> - ]; + The functional scope of the dialect (together with tensor and + linalg dialects) is the + [array-API](https://data-apis.org/array-api/latest/index.html). - let extraClassDeclaration = [{ - ::mlir::MemRefType getMemRefType(::mlir::Value = {}) const; - ::mlir::RankedTensorType getTensorType() const; - ::imex::ndarray::NDArrayType cloneWithDynDims() const; - bool hasUnitSize() const; - bool hasZeroSize() const; - // ShapedTypeInterface - using ::mlir::ShapedType::Trait::clone; - using ::mlir::ShapedType::Trait::getElementTypeBitWidth; - using ::mlir::ShapedType::Trait::getRank; - using ::mlir::ShapedType::Trait::getNumElements; - using ::mlir::ShapedType::Trait::isDynamicDim; - using ::mlir::ShapedType::Trait::hasStaticShape; - using ::mlir::ShapedType::Trait::getNumDynamicDims; - using ::mlir::ShapedType::Trait::getDimSize; - using ::mlir::ShapedType::Trait::getDynamicDimIndex; - }]; + }]; + // We use the default parser/printer which handles registered attrs + let useDefaultAttributePrinterParser = true; } +def NDArray_EnvironmentAttr : AttrDef { + let mnemonic = "environment"; + let parameters = (ins ArrayRefParameter<"::mlir::Attribute">:$envs); + let assemblyFormat = "`<` $envs `>`"; +} // Base class for dialect operations. This operation inherits from the base // `Op` class in OpBase.td, and provides: @@ -160,101 +74,38 @@ class NDArray_Op traits = []> : Op; -def DeleteOp : NDArray_Op<"delete", [ - DeclareOpInterfaceMethods]> { - let summary = "Explicitly delete an NDArray, freeing its memory"; - let description = [{ - Allow explicitly deleting the memory of an NDArray. It is assumed - that the memory had been allocated by one of NDArray's creation functions. - It must be the last use of the input array. - }]; - - let arguments = (ins NDArray_NDArray:$input); - - let assemblyFormat = [{ - $input attr-dict `:` qualified(type($input)) - }]; -} - +def CopyOp : NDArray_Op<"copy", [CopyOpInterface, SameOperandsAndResultShape, SameOperandsAndResultElementType]> { -def FromMemRefOp : NDArray_Op<"from_memref", [Pure]> { - let summary = "Convert a builtin memref value to a value of type NDArray"; let description = [{ - Result type possibly adds NDArray annotations. - }]; - - let arguments = (ins AnyMemRef:$input); - let results = (outs NDArray_NDArray); - - let assemblyFormat = [{ - $input attr-dict `:` qualified(type($input)) `->` qualified(type(results)) - }]; -} - + Copies the data from the source to the new result array. -def ToTensorOp : NDArray_Op<"to_tensor", [Pure]> { - let summary = "Convert a NDArray value to a value of MLIR's builtin tensor type"; - let description = [{ - Convert a NDArray value to a value of MLIR's builtin tensor type. - Removes all annotations provided by the environments. + Source and result are expected to have the same element type and shape. + Otherwise, the result is undefined. }]; - let arguments = (ins AnyType:$input); - let results = (outs AnyRankedTensor); + let arguments = (ins Arg:$source); + let results = (outs AnyRankedTensor:$target); let assemblyFormat = [{ - $input attr-dict `:` qualified(type($input)) `->` qualified(type(results)) + $source attr-dict `:` qualified(type($source)) `->` qualified(type($target)) }]; - - let builders = [ - // auto-deduce return type - OpBuilder<(ins "::mlir::Value":$tnsr), [{ - auto mrtyp = mlir::dyn_cast<::imex::ndarray::NDArrayType>(tnsr.getType()).getTensorType(); - assert(mrtyp); - build($_builder, $_state, mrtyp, tnsr); - }]>, - ]; } -def DimOp : NDArray_Op<"dim", [ - DeclareOpInterfaceMethods, - ConditionallySpeculatable, NoMemoryEffect, - ShapedDimOpInterface]> { - let summary = "Dimension index operation"; +def DeleteOp : NDArray_Op<"delete", [ + DeclareOpInterfaceMethods]> { + let summary = "Explicitly delete an nd-array, freeing its memory"; let description = [{ - The `dim` operation takes a array and a dimension operand of type `index`. - It returns the size of the requested dimension of the given array. - If the dimension index is out of bounds the behavior is undefined. + Allow explicitly deleting the memory of an nd-array. It is assumed + that the memory had been allocated by one of nd-array's creation functions. + It must be the last use of the input array. }]; - let arguments = (ins AnyType:$source, Index:$index); - let results = (outs Index:$result); + let arguments = (ins AnyRankedTensor:$input); let assemblyFormat = [{ - $source $index attr-dict `:` qualified(type($source)) `->` qualified(type($result)) - }]; - - let builders = [ - OpBuilder<(ins "::mlir::Value":$source, "int64_t":$index)>, - ]; - - let extraClassDeclaration = [{ - /// Helper function to get the index as a simple integer if it is constant. - std::optional getConstantIndex(); - - /// Interface method of ShapedDimOpInterface: Return the source tensor. - ::mlir::Value getShapedValue() { return getSource(); } - - /// Interface method of ShapedDimOpInterface: Return the dimension. - ::mlir::OpFoldResult getDimension() { return getIndex(); } - - /// Interface method for ConditionallySpeculatable. - ::mlir::Speculation::Speculatability getSpeculatability(); + $input attr-dict `:` qualified(type($input)) }]; - - let hasFolder = 1; - let hasCanonicalizer = 1; } @@ -294,7 +145,7 @@ def SubviewOp : NDArray_OpWithOffsetSizesAndStrides<"subview", [ }]; let arguments = (ins - AnyType:$source, + AnyRankedTensor:$source, Variadic:$offsets, Variadic:$sizes, Variadic:$strides, @@ -302,7 +153,7 @@ def SubviewOp : NDArray_OpWithOffsetSizesAndStrides<"subview", [ DenseI64ArrayAttr:$static_sizes, DenseI64ArrayAttr:$static_strides ); - let results = (outs AnyType:$result); + let results = (outs AnyRankedTensor:$result); let assemblyFormat = [{ $source `` @@ -448,13 +299,13 @@ def ExtractSliceOp : NDArray_OpWithOffsetSizesAndStrides<"extract_slice", [ let description = [{ The "extract_slice" operation extract a view from another array as specified by the operation's offsets, sizes and strides arguments. - The returned array is guaranteed to be a view of the source array. + The returned array is not guaranteed to be a view of the source array. This operation is expectecd to eventually lower to tensor.extract_slice. }]; let arguments = (ins - AnyType:$source, + AnyRankedTensor:$source, Variadic:$offsets, Variadic:$sizes, Variadic:$strides, @@ -462,7 +313,7 @@ def ExtractSliceOp : NDArray_OpWithOffsetSizesAndStrides<"extract_slice", [ DenseI64ArrayAttr:$static_sizes, DenseI64ArrayAttr:$static_strides ); - let results = (outs AnyType:$result); + let results = (outs AnyRankedTensor:$result); let assemblyFormat = [{ $source `` @@ -575,8 +426,8 @@ def InsertSliceOp : NDArray_OpWithOffsetSizesAndStrides<"insert_slice", [ }]; let arguments = (ins - AnyType:$destination, - AnyType:$source, + AnyRankedTensor:$destination, + AnyRankedTensor:$source, Variadic:$offsets, Variadic:$sizes, Variadic:$strides, @@ -671,8 +522,8 @@ def ImmutableInsertSliceOp : NDArray_OpWithOffsetSizesAndStrides<"immutable_inse }]; let arguments = (ins - AnyType:$destination, - AnyType:$source, + AnyRankedTensor:$destination, + AnyRankedTensor:$source, Variadic:$offsets, Variadic:$sizes, Variadic:$strides, @@ -680,7 +531,7 @@ def ImmutableInsertSliceOp : NDArray_OpWithOffsetSizesAndStrides<"immutable_inse DenseI64ArrayAttr:$static_sizes, DenseI64ArrayAttr:$static_strides ); - let results = (outs NDArray_NDArray:$result); + let results = (outs AnyRankedTensor:$result); let assemblyFormat = [{ $source `into` $destination custom($offsets, $static_offsets) custom($sizes, $static_sizes) custom($strides, $static_strides) attr-dict `:` qualified(type($source)) `into` qualified(type($destination)) @@ -745,71 +596,6 @@ def ImmutableInsertSliceOp : NDArray_OpWithOffsetSizesAndStrides<"immutable_inse } -def LoadOp : NDArray_Op<"load", - [TypesMatchWith<"result type matches element type of 'array'", - "array", "result", - "mlir::cast($_self).getElementType()">]> { - let summary = "array element load operation"; - let description = [{ - The `load` op reads an element from an array specified by an index list. The - output of load is a new value with the same type as the elements of the - array. The arity of indices is the rank of the array (i.e., if the array - loaded from is of rank 3, then 3 indices are required for the load following - the array identifier). - }]; - - let arguments = (ins AnyType:$array, Variadic:$indices); - let results = (outs AnyType:$result); - - let assemblyFormat = [{ - $array `[` $indices `]` attr-dict `:` qualified(type($array)) - }]; - - let builders = [ - OpBuilder<(ins "::mlir::Value":$array, CArg<"::mlir::ValueRange", "{}">:$indices), [{ - auto arrayType = mlir::cast(array.getType()); - $_state.addOperands(array); - $_state.addOperands(indices); - $_state.types.push_back(arrayType.getElementType()); - }]>]; - -} - - -def CopyOp : NDArray_Op<"copy", [CopyOpInterface, SameOperandsAndResultShape, SameOperandsAndResultElementType]> { - - let description = [{ - Copies the data from the source to the new result array. - - Source and result are expected to have the same element type and shape. - Otherwise, the result is undefined. - }]; - - let arguments = (ins Arg:$source); - let results = (outs NDArray_NDArray:$target); - - let assemblyFormat = [{ - $source attr-dict `:` qualified(type($source)) `->` qualified(type($target)) - }]; -} - - -def CastOp : NDArray_Op<"cast", [ - DeclareOpInterfaceMethods, - Pure]> { - let summary = "Cast a NDArray to a compatible NDArray type"; - - let arguments = (ins NDArray_NDArray:$source); - let results = (outs NDArray_NDArray:$destination); - - let assemblyFormat = [{ - $source attr-dict `:` qualified(type($source)) `to` qualified(type($destination)) - }]; - - let hasCanonicalizer = 1; -} - - def LinSpaceOp : NDArray_Op<"linspace", [Pure]> { let summary = "Returns evenly spaced numbers over a specified interval."; let description = [{ @@ -818,7 +604,7 @@ def LinSpaceOp : NDArray_Op<"linspace", [Pure]> { }]; let arguments = (ins AnyType:$start, AnyType:$stop, AnyType:$num, UnitAttr:$endpoint); - let results = (outs NDArray_NDArray); + let results = (outs AnyRankedTensor); let assemblyFormat = [{ $start $stop $num (`true` $endpoint^):(`false`)? attr-dict `:` `(` type(operands) `)` `->` qualified(type(results)) @@ -831,7 +617,7 @@ def LinSpaceOp : NDArray_Op<"linspace", [Pure]> { CArg<"::mlir::ArrayRef<::mlir::Attribute>", "{}">:$environments), [{ auto dt = toMLIR($_builder, dtype); build($_builder, $_state, - ::imex::ndarray::NDArrayType::get(getShapeFromValues(num), dt, environments), + ::mlir::RankedTensorType::get(getShapeFromValues(num), dt), start, stop, num, endpoint); }]>, ]; @@ -841,11 +627,11 @@ def LinSpaceOp : NDArray_Op<"linspace", [Pure]> { def CreateOp : NDArray_Op<"create", [Pure, AttrSizedOperandSegments]> { - let summary = "Returns a new NDArray having a specified shape and type and optionally filled with a value."; + let summary = "Returns a new tensor having a specified shape and type and optionally filled with a value."; let arguments = (ins Variadic:$shape, Optional:$value); // result is a ndarray - let results = (outs NDArray_NDArray); + let results = (outs AnyRankedTensor); let assemblyFormat = [{ $shape oilist(`value` $value) attr-dict `:` `(` type(operands) `)` `->` qualified(type(results)) @@ -857,7 +643,7 @@ def CreateOp : NDArray_Op<"create", [Pure, AttrSizedOperandSegments]> { CArg<"::mlir::ArrayRef<::mlir::Attribute>", "{}">:$environments), [{ auto dt = toMLIR($_builder, dtype); build($_builder, $_state, - ::imex::ndarray::NDArrayType::get(getShapeFromValues(shape), dt, environments), + ::mlir::RankedTensorType::get(getShapeFromValues(shape), dt), shape, value); }]>, ]; @@ -882,93 +668,15 @@ def ReshapeOp : NDArray_Op<"reshape", []> { } -def EWBinOp : NDArray_Op<"ewbin", []> { - let summary = "Apply elementwise binary operation"; - let description = [{ - Apply the `op(lhs[i], rhs[i])` on all elements `i` and return a new ndarray. - Apply the broadcasting and type promotaions rules of the array-API - to operator and result types. - }]; - - // ewbin takes 2 NDArrayType operands: lhs and rhs - let arguments = (ins AnyAttr:$op, AnyType:$lhs, AnyType:$rhs); - // result is a ndarray - let results = (outs AnyType); - - let assemblyFormat = [{ - $lhs `,` $rhs attr-dict `:` `(`qualified(type(operands))`)` `->` qualified(type(results)) - }]; - - let hasCanonicalizer = 1; -} - - -def EWUnyOp : NDArray_Op<"ewuny", []> { - let summary = "Apply elementwise unary operation"; - let description = [{ - Apply the `op(src[i])` on all elements `i` and return a new ndarray. - }]; - - // ewuny takes 1 operand (NDArrayType) and one attribute (unary operation) - let arguments = (ins AnyAttr:$op, AnyType:$src); - // result is a ndarray - let results = (outs AnyType); - - let assemblyFormat = [{ - $src attr-dict `:` qualified(type($src)) `->` qualified(type(results)) - }]; - - let hasCanonicalizer = 1; -} - - -def ReductionOp : NDArray_Op<"reduction", []> { - let summary = "Apply reduction operation"; - let description = [{ - Apply the reduction operation `op` over all elements of `input`. - The produced result is a 0-dim tensor with the same dtype as `input`. - }]; - - // reduction takes 1 operand (NDArrayType) and one attribute (reduction operation) - let arguments = (ins AnyAttr:$op, AnyType:$input); - // result is a ndarray - let results = (outs NDArray_NDArray); - - let assemblyFormat = [{ - $input attr-dict `:` qualified(type($input)) `->` qualified(type(results)) - }]; -} - def CastElemTypeOp: NDArray_Op<"cast_elemtype", [Pure]> { let summary = "Cast array from one element type to another"; let arguments = (ins AnyType:$input, OptionalAttr:$copy); - let results = (outs NDArray_NDArray:$output); + let results = (outs AnyRankedTensor:$output); let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `to` qualified(type($output))"; let hasCanonicalizer = 1; } -def PermuteDimsOp : NDArray_Op<"permute_dims", []> { - let summary = "Permutes the axes (dimensions) of an array to a new array."; - let description = [{ - Permutes the axes (dimensions) of an array. - The output array is a new array. - }]; - - let arguments = (ins - NDArray_NDArray:$source, - DenseI64ArrayAttr:$axes - ); - let results = (outs NDArray_NDArray); - - let assemblyFormat = [{ - $source $axes attr-dict `:` qualified(type($source)) `->` qualified(type(results)) - }]; - - let hasCanonicalizer = 1; - let hasVerifier = 1; -} - #endif // _NDARRAY_OPS_TD_INCLUDED_ diff --git a/include/imex/Dialect/NDArray/Transforms/Passes.h b/include/imex/Dialect/NDArray/Transforms/Passes.h index 0b90c1b61..99db67de2 100644 --- a/include/imex/Dialect/NDArray/Transforms/Passes.h +++ b/include/imex/Dialect/NDArray/Transforms/Passes.h @@ -28,19 +28,10 @@ class RewritePatternSet; namespace imex { -//===----------------------------------------------------------------------===// -/// NDArray passes. -//===----------------------------------------------------------------------===// - -/// Create a NDArrayDist pass -std::unique_ptr<::mlir::Pass> createNDArrayDistPass(); - -/// Populate the given list with patterns which add Dist-Ops to NDArray ops -void populateNDArrayDistPatterns(::mlir::LLVMTypeConverter &converter, - ::mlir::RewritePatternSet &patterns); - /// Create a AddGPURegions pass std::unique_ptr<::mlir::Pass> createAddGPURegionsPass(); +/// Create a ShardingCoalesce pass +std::unique_ptr<::mlir::Pass> createCoalesceShardOpsPass(); #define GEN_PASS_DECL #include diff --git a/include/imex/Dialect/NDArray/Transforms/Passes.td b/include/imex/Dialect/NDArray/Transforms/Passes.td index f064f8700..b53374ee0 100644 --- a/include/imex/Dialect/NDArray/Transforms/Passes.td +++ b/include/imex/Dialect/NDArray/Transforms/Passes.td @@ -17,34 +17,6 @@ include "mlir/Pass/PassBase.td" -//===----------------------------------------------------------------------===// -// NDArrayDist -//===----------------------------------------------------------------------===// - -def NDArrayDist : Pass<"ndarray-dist"> { - let summary = "Use Dist-Ops to enable distributed NDArray Ops"; - let description = [{ - Transforms NDArray Ops into a sequence of operations to enable compute-follows-data - for distributed memory. Using the Dist dialect for disribution operations. - - #### Output IR - - Dist dialect - - NDArray dialect - - Linalg dialect - - Arith dialect - }]; - let constructor = "imex::createNDArrayDistPass()"; - let dependentDialects = ["::imex::ndarray::NDArrayDialect", - "::imex::dist::DistDialect", - "::imex::distruntime::DistRuntimeDialect", - "::mlir::arith::ArithDialect", - "::mlir::bufferization::BufferizationDialect", - "::mlir::linalg::LinalgDialect", - "::mlir::tensor::TensorDialect", - "::mlir::memref::MemRefDialect"]; - let options = []; -} - def AddGPURegions : Pass<"add-gpu-regions"> { let summary = "Add RegionOps around NDArray Ops where applicable."; let description = [{ @@ -57,4 +29,22 @@ def AddGPURegions : Pass<"add-gpu-regions"> { let options = []; } +//===----------------------------------------------------------------------===// +// CoalesceShardOps +//===----------------------------------------------------------------------===// + +def CoalesceShardOps : Pass<"coalesce-shard-ops"> { + let summary = "Coalesce shard operations from mesh dialect."; + let description = [{ + Combine shard ops which would lead to resharding of tensors. + This pass handles coalesing of shard ops which annotate ndarray's + subview operations. + }]; + let constructor = "imex::createCoalesceShardOpsPass()"; + let dependentDialects = ["::mlir::mesh::MeshDialect", + "::mlir::arith::ArithDialect", + "::mlir::tensor::TensorDialect", + "::mlir::memref::MemRefDialect"]; +} + #endif // _NDARRAY_PASSES_TD_INCLUDED_ diff --git a/include/imex/Dialect/NDArray/Utils/Utils.h b/include/imex/Dialect/NDArray/Utils/Utils.h index 2fef4ca74..1bdbcebaa 100644 --- a/include/imex/Dialect/NDArray/Utils/Utils.h +++ b/include/imex/Dialect/NDArray/Utils/Utils.h @@ -102,14 +102,14 @@ inline ::mlir::Value createDType(::mlir::Location &loc, template auto createShapeOf(::mlir::Location loc, ::mlir::OpBuilder &builder, ::mlir::Value lPTnsr) { - auto arType = mlir::dyn_cast<::imex::ndarray::NDArrayType>(lPTnsr.getType()); + auto arType = mlir::dyn_cast<::mlir::RankedTensorType>(lPTnsr.getType()); assert(arType); auto rank = arType.getRank(); T dims; for (int64_t i = 0; i < rank; ++i) { dims.emplace_back( - builder.createOrFold<::imex::ndarray::DimOp>(loc, lPTnsr, i)); + builder.createOrFold<::mlir::tensor::DimOp>(loc, lPTnsr, i)); } return dims; @@ -118,10 +118,9 @@ auto createShapeOf(::mlir::Location loc, ::mlir::OpBuilder &builder, // convert an unranked memref from a NDArray inline ::mlir::Value mkURMemRef(::mlir::Location loc, ::mlir::OpBuilder &builder, ::mlir::Value src) { - auto srcArType = mlir::cast<::imex::ndarray::NDArrayType>(src.getType()); - auto bMRTyp = srcArType.getMemRefType(); - auto bTensor = builder.create<::imex::ndarray::ToTensorOp>(loc, src); - auto bMRef = createToMemRef(loc, builder, bTensor, bMRTyp); + auto srcArType = mlir::cast<::mlir::RankedTensorType>(src.getType()); + auto bMRTyp = getMemRefType(srcArType); + auto bMRef = createToMemRef(loc, builder, src, bMRTyp); return createUnrankedMemRefCast(builder, loc, bMRef); } diff --git a/include/imex/Dialect/Region/Transforms/Passes.td b/include/imex/Dialect/Region/Transforms/Passes.td index 876dd6f55..f66943421 100644 --- a/include/imex/Dialect/Region/Transforms/Passes.td +++ b/include/imex/Dialect/Region/Transforms/Passes.td @@ -18,7 +18,7 @@ include "mlir/Pass/PassBase.td" //===----------------------------------------------------------------------===// -// DistCoalesce +// RegionBufferize //===----------------------------------------------------------------------===// def RegionBufferize : Pass<"region-bufferize"> { diff --git a/include/imex/InitIMEXDialects.h b/include/imex/InitIMEXDialects.h index e17b5cc80..1b3e30f65 100644 --- a/include/imex/InitIMEXDialects.h +++ b/include/imex/InitIMEXDialects.h @@ -18,11 +18,10 @@ #include #include -#include #include #include -#include #include +#include #include #include @@ -31,8 +30,7 @@ namespace imex { /// Add all the IMEX dialects to the provided registry. inline void registerAllDialects(::mlir::DialectRegistry ®istry) { // clang-format off - registry.insert<::imex::dist::DistDialect, - ::imex::distruntime::DistRuntimeDialect, + registry.insert<::imex::distruntime::DistRuntimeDialect, ::imex::ndarray::NDArrayDialect, ::imex::region::RegionDialect, ::imex::xetile::XeTileDialect, diff --git a/include/imex/InitIMEXPasses.h b/include/imex/InitIMEXPasses.h index 3d3afb87f..cd29b15f1 100644 --- a/include/imex/InitIMEXPasses.h +++ b/include/imex/InitIMEXPasses.h @@ -17,7 +17,6 @@ #include // #include -#include #include #include #include @@ -45,7 +44,6 @@ inline void registerAllPasses() { // Dialect passes registerNDArrayPasses(); - registerDistPasses(); registerDistRuntimePasses(); registerRegionPasses(); registerXeTilePasses(); diff --git a/include/imex/Utils/PassUtils.h b/include/imex/Utils/PassUtils.h index 45ed12d01..c8b4404b6 100644 --- a/include/imex/Utils/PassUtils.h +++ b/include/imex/Utils/PassUtils.h @@ -147,6 +147,11 @@ extern ::mlir::MemRefType getMemRefType(::mlir::MLIRContext *ctxt, ::mlir::Type elType, bool strided = true); +inline ::mlir::MemRefType getMemRefType(::mlir::RankedTensorType ttype) { + return getMemRefType(ttype.getContext(), ttype.getShape(), + ttype.getElementType()); +} + /// Create a 1d MemRef alloc with given size and elType extern ::mlir::Value createAllocMR(::mlir::OpBuilder &builder, ::mlir::Location loc, ::mlir::Type elType, @@ -329,7 +334,7 @@ inline std::string mkTypedFunc(const ::std::string &base, ::mlir::Type elType) { // helper for sorting operations struct opOrderCmp { - opOrderCmp(::mlir::DominanceInfo &dom) : _dom(dom){}; + opOrderCmp(::mlir::DominanceInfo &dom) : _dom(dom) {}; ::mlir::DominanceInfo &_dom; bool operator()(::mlir::Operation *i, ::mlir::Operation *j) const { if (_dom.dominates(i, j)) { diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index b52337e57..5e31e8c11 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -1,4 +1,3 @@ -add_subdirectory(DistToStandard) add_subdirectory(NDArrayToLinalg) add_subdirectory(DropRegions) add_subdirectory(GPUToSPIRV) diff --git a/lib/Conversion/DistToStandard/CMakeLists.txt b/lib/Conversion/DistToStandard/CMakeLists.txt deleted file mode 100644 index 7f3abf29d..000000000 --- a/lib/Conversion/DistToStandard/CMakeLists.txt +++ /dev/null @@ -1,13 +0,0 @@ -add_imex_conversion_library(IMEXDistToStandard - DistToStandard.cpp - - ADDITIONAL_HEADER_DIRS - ${MLIR_MAIN_INCLUDE_DIR}/imex/Conversion/DistToStandard - - DEPENDS - IMEXConversionPassIncGen - - LINK_LIBS PUBLIC - IMEXNDArrayDialect - MLIRLinalgDialect -) diff --git a/lib/Conversion/DistToStandard/DistToStandard.cpp b/lib/Conversion/DistToStandard/DistToStandard.cpp deleted file mode 100644 index 412235f88..000000000 --- a/lib/Conversion/DistToStandard/DistToStandard.cpp +++ /dev/null @@ -1,1897 +0,0 @@ -//===- DistToStandard.cpp - DistToStandard conversion ----------*- C++ -*-===// -// -// Copyright 2023 Intel Corporation -// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -/// -/// \file -/// This file implements the DistToStandard conversion, converting the Dist -/// dialect to standard dialects (including DistRuntime and NDArray). -/// -//===----------------------------------------------------------------------===// - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -namespace imex { -#define GEN_PASS_DEF_CONVERTDISTTOSTANDARD -#include "imex/Conversion/Passes.h.inc" -} // namespace imex - -using ::imex::ndarray::createDType; -using ::imex::ndarray::createShapeOf; - -namespace imex { -namespace dist { -namespace { - -// ******************************* -// ***** Individual patterns ***** -// ******************************* - -/// Rewriting ::imex::ndarray::LinSpaceOp to get a distributed linspace if -/// applicable. -struct LinSpaceOpConverter - : public ::mlir::OpConversionPattern<::imex::ndarray::LinSpaceOp> { - using ::mlir::OpConversionPattern< - ::imex::ndarray::LinSpaceOp>::OpConversionPattern; - - ::mlir::LogicalResult - matchAndRewrite(::imex::ndarray::LinSpaceOp op, - ::imex::ndarray::LinSpaceOp::Adaptor adaptor, - ::mlir::ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto retArType = mlir::dyn_cast<::imex::ndarray::NDArrayType>(op.getType()); - if (!retArType) - return ::mlir::failure(); - auto dEnv = getDistEnv(retArType); - // nothing to do if not distributed - if (!dEnv) - return ::mlir::failure(); - - auto team = dEnv.getTeam(); - auto start = op.getStart(); - auto stop = op.getStop(); - auto count = op.getNum(); - bool endpoint = op.getEndpoint(); - - if (!(start.getType().isIntOrIndexOrFloat() && - stop.getType().isIntOrIndexOrFloat() && - count.getType().isIntOrIndex() && retArType)) { - return ::mlir::failure(); - } // FIXME type promotion - - // cast types and get step - auto elTyp = retArType.getElementType(); - count = createIndexCast(loc, rewriter, count); - auto bw = elTyp.isIndex() ? 64 : elTyp.getIntOrFloatBitWidth(); - ::mlir::Type cType = - bw > 32 ? rewriter.getF64Type() - : (bw > 16 ? rewriter.getF32Type() : rewriter.getF16Type()); - start = createCast(loc, rewriter, start, cType); - stop = createCast(loc, rewriter, stop, cType); - auto step = - createStepLinSpace(rewriter, loc, start, stop, count, endpoint, cType); - - // get number of procs and prank - auto nProcs = createNProcs(loc, rewriter, team); - auto pRank = createPRank(loc, rewriter, team); - - // get local shape and offsets - auto lPart = rewriter.create<::imex::dist::DefaultPartitionOp>( - loc, nProcs, pRank, ::mlir::ValueRange{count}); - auto lShape = lPart.getLShape(); - auto lOffs = lPart.getLOffsets(); - - // use local shape and offset to compute local linspace - auto off = createCast(loc, rewriter, lOffs[0], cType); - auto lSz = createCast(loc, rewriter, lShape[0], cType); - - start = rewriter.createOrFold<::mlir::arith::AddFOp>( - loc, rewriter.createOrFold<::mlir::arith::MulFOp>(loc, step, off), - start); - stop = rewriter.createOrFold<::mlir::arith::AddFOp>( - loc, rewriter.createOrFold<::mlir::arith::MulFOp>(loc, step, lSz), - start); - - // finally create local linspace - auto res = rewriter.create<::imex::ndarray::LinSpaceOp>( - loc, start, stop, lShape[0], false, ndarray::fromMLIR(elTyp), - getNonDistEnvs(retArType)); - - rewriter.replaceOp(op, createDistArray(loc, rewriter, team, {op.getNum()}, - lOffs, res.getResult())); - return ::mlir::success(); - } -}; - -/// Rewriting ::imex::ndarray::CreateOp to get a distributed CreateOp if -/// applicable. Create global, distributed output array as defined by operands. -/// The local partition (e.g. a RankedTensor) are wrapped in a -/// non-distributed NDArray and re-applied to CreateOp. -/// op gets replaced with global distributed array -struct CreateOpConverter - : public ::mlir::OpConversionPattern<::imex::ndarray::CreateOp> { - using ::mlir::OpConversionPattern< - ::imex::ndarray::CreateOp>::OpConversionPattern; - - ::mlir::LogicalResult - matchAndRewrite(::imex::ndarray::CreateOp op, - ::imex::ndarray::CreateOp::Adaptor adaptor, - ::mlir::ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto retArType = mlir::dyn_cast<::imex::ndarray::NDArrayType>(op.getType()); - if (!retArType) - return ::mlir::failure(); - auto dEnv = getDistEnv(retArType); - // nothing to do if not distributed - if (!dEnv) - return ::mlir::failure(); - - auto team = dEnv.getTeam(); - auto gShape = op.getShape(); - // get local shape and offsets - auto lPart = createDefaultPartition(loc, rewriter, team, gShape); - - // finally create local array - auto arres = rewriter.create<::imex::ndarray::CreateOp>( - loc, lPart.getLShape(), ndarray::fromMLIR(retArType.getElementType()), - op.getValue(), getNonDistEnvs(retArType)); - - rewriter.replaceOp(op, - createDistArray(loc, rewriter, team, gShape, - lPart.getLOffsets(), arres.getResult())); - return ::mlir::success(); - } -}; - -/// Convert a CopyOp on a distributed array to CopyOps on the local data. -struct CopyOpConverter - : public ::mlir::OpConversionPattern<::imex::ndarray::CopyOp> { - using ::mlir::OpConversionPattern< - ::imex::ndarray::CopyOp>::OpConversionPattern; - - /// Initialize the pattern. - void initialize() { - /// Signal that this pattern safely handles recursive application. - setHasBoundedRewriteRecursion(); - } - - ::mlir::LogicalResult - matchAndRewrite(::imex::ndarray::CopyOp op, - ::imex::ndarray::CopyOp::Adaptor adaptor, - ::mlir::ConversionPatternRewriter &rewriter) const override { - - auto src = op.getSource(); - auto srcDistType = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(src.getType()); - auto resDistType = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(op.getResult().getType()); - // return failure if wrong ops or not distributed - if (!srcDistType || !isDist(srcDistType) || !resDistType || - !isDist(resDistType)) { - return ::mlir::failure(); - } - // FIXME: check if part shapes are compatible - - auto loc = op.getLoc(); - auto partTypes = getPartsTypes(resDistType); - auto lParts = createPartsOf(loc, rewriter, src); - auto lOffsets = createLocalOffsetsOf(loc, rewriter, src); - - // apply CopyOp to all parts - ::imex::ValVec resParts; - for (auto i = 0u; i < lParts.size(); ++i) { - auto partOp = rewriter.create<::imex::ndarray::CopyOp>(loc, partTypes[i], - lParts[i]); - resParts.emplace_back(partOp.getResult()); - } - - // get global shape - auto gShape = resDistType.getShape(); - // and init our new dist array - rewriter.replaceOp(op, createDistArray(loc, rewriter, - getDistEnv(srcDistType).getTeam(), - gShape, lOffsets, resParts)); - - return ::mlir::success(); - } -}; - -/// Convert a DeleteOp on a distributed array to DeleteOps on the local data. -struct DeleteOpConverter - : public ::mlir::OpConversionPattern<::imex::ndarray::DeleteOp> { - using ::mlir::OpConversionPattern< - ::imex::ndarray::DeleteOp>::OpConversionPattern; - - /// Initialize the pattern. - void initialize() { - /// Signal that this pattern safely handles recursive application. - setHasBoundedRewriteRecursion(); - } - - ::mlir::LogicalResult - matchAndRewrite(::imex::ndarray::DeleteOp op, - ::imex::ndarray::DeleteOp::Adaptor adaptor, - ::mlir::ConversionPatternRewriter &rewriter) const override { - - auto src = op.getInput(); - auto srcDistType = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(src.getType()); - // return failure if wrong ops or not distributed - if (!srcDistType || !isDist(srcDistType)) { - return ::mlir::failure(); - } - - auto loc = op.getLoc(); - auto lParts = createPartsOf(loc, rewriter, src); - - // apply DeleteOp to all parts - for (auto p : lParts) { - auto newOp = rewriter.create<::imex::ndarray::DeleteOp>(loc, p); - newOp->setAttrs(adaptor.getAttributes()); - } - - rewriter.eraseOp(op); - - return ::mlir::success(); - } -}; - -// extract RankedTensor and create ::imex::dist::AllReduceOp -inline ::imex::distruntime::AllReduceOp -createAllReduce(::mlir::Location &loc, ::mlir::OpBuilder &builder, - ::mlir::Attribute op, ::mlir::Value ndArray) { - auto arType = mlir::dyn_cast<::imex::ndarray::NDArrayType>(ndArray.getType()); - assert(arType); - auto lArray = builder.create<::imex::ndarray::ToTensorOp>(loc, ndArray); - auto lMRef = createToMemRef(loc, builder, lArray, arType.getMemRefType()); - return builder.create<::imex::distruntime::AllReduceOp>(loc, op, lMRef); -} - -/// Rewrite ::imex::ndarray::ReductionOp to get a distributed -/// reduction if operand is distributed. -/// Create global, distributed 0d output array. -/// The local partitions of operand (e.g. RankedTensor) is wrapped in -/// non-distributed NDArray and re-applied to reduction. -/// The result is then applied to a distributed allreduce. -/// op gets replaced with global distributed array -struct ReductionOpConverter - : public ::mlir::OpConversionPattern<::imex::ndarray::ReductionOp> { - using ::mlir::OpConversionPattern< - ::imex::ndarray::ReductionOp>::OpConversionPattern; - - ::mlir::LogicalResult - matchAndRewrite(::imex::ndarray::ReductionOp op, - ::imex::ndarray::ReductionOp::Adaptor adaptor, - ::mlir::ConversionPatternRewriter &rewriter) const override { - // FIXME reduction over individual dimensions is not supported - auto loc = op.getLoc(); - auto inp = op.getInput(); - auto inpDistTyp = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(inp.getType()); - // nothing to do if not distributed - if (!inpDistTyp || !isDist(inpDistTyp)) - return ::mlir::failure(); - - // Local reduction - auto parts = createPartsOf(loc, rewriter, inp); - auto local = parts.size() == 1 ? parts[0] : parts[1]; - auto retArType = cloneAsNonDist(op.getType()); - auto redArray = rewriter.create<::imex::ndarray::ReductionOp>( - loc, retArType, op.getOp(), local); - // global reduction - (void)createAllReduce(loc, rewriter, op.getOp(), redArray); - - // init our new dist array - // FIXME result shape is 0d always - rewriter.replaceOp(op, createDistArray(loc, rewriter, - getDistEnv(inpDistTyp).getTeam(), - ::mlir::SmallVector(), {}, - redArray.getResult())); - return ::mlir::success(); - } -}; - -/// Rewriting ::imex::ndarray::ToTensorOp -/// Get NDArray from distributed array and apply to ToTensorOp. -struct ToTensorOpConverter - : public ::mlir::OpConversionPattern<::imex::ndarray::ToTensorOp> { - using ::mlir::OpConversionPattern< - ::imex::ndarray::ToTensorOp>::OpConversionPattern; - - ::mlir::LogicalResult - matchAndRewrite(::imex::ndarray::ToTensorOp op, - ::imex::ndarray::ToTensorOp::Adaptor adaptor, - ::mlir::ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - // get input - auto inpArTyp = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(op.getInput().getType()); - if (!inpArTyp || !isDist(inpArTyp)) { - return ::mlir::failure(); - } - auto parts = createPartsOf(loc, rewriter, op.getInput()); - auto part = parts.size() == 1 ? parts[0] : parts[1]; - rewriter.replaceOpWithNewOp<::imex::ndarray::ToTensorOp>(op, part); - return ::mlir::success(); - } -}; - -/// Convert a global dist::SubviewOP to ndarray::SubviewOp on the local data. -/// Computes overlap of slice, local parts and target. -/// Even though the op accepts static offs/sizes all computation -/// is done on values - only static dim-sizes of 1 are properly propagated. -/// Static strides are always propagated to NDArray. -struct SubviewOpConverter - : public ::mlir::OpConversionPattern<::imex::dist::SubviewOp> { - using ::mlir::OpConversionPattern< - ::imex::dist::SubviewOp>::OpConversionPattern; - - /// Initialize the pattern. - void initialize() { - /// Signal that this pattern safely handles recursive application. - setHasBoundedRewriteRecursion(); - } - - ::mlir::LogicalResult - matchAndRewrite(::imex::dist::SubviewOp op, - ::imex::dist::SubviewOp::Adaptor adaptor, - ::mlir::ConversionPatternRewriter &rewriter) const override { - // get input and type - auto src = op.getSource(); - auto inpDistTyp = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(src.getType()); - auto resDistType = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(op.getResult().getType()); - if (!inpDistTyp || !isDist(inpDistTyp) || !resDistType || - !isDist(resDistType)) { - return ::mlir::failure(); - } - - auto loc = op.getLoc(); - // Get the local part of the global slice, team, rank, offsets - auto _slcOffs = adaptor.getOffsets(); - auto _slcSizes = adaptor.getSizes(); - auto _slcStrides = adaptor.getStrides(); - auto sSlcOffs = adaptor.getStaticOffsets(); - auto sSlcSizes = adaptor.getStaticSizes(); - auto sSlcStrides = adaptor.getStaticStrides(); - auto tOffs = adaptor.getTargetOffsets(); - auto tSizes = adaptor.getTargetSizes(); - auto rank = std::max(sSlcOffs.size(), _slcOffs.size()); - bool hasTarget = tOffs.size() > 0; - - // get offs, sizes strides as values - auto slcOffs = getMixedAsValues(loc, rewriter, _slcOffs, sSlcOffs); - auto slcSizes = getMixedAsValues(loc, rewriter, _slcSizes, sSlcSizes); - auto slcStrides = getMixedAsValues(loc, rewriter, _slcStrides, sSlcStrides); - - ::imex::ValVec lViews, lVOffsets; - auto srcParts = createPartsOf(loc, rewriter, src); - ::imex::ValVec lOffs = createLocalOffsetsOf(loc, rewriter, src); - ::mlir::SmallVector shift(rank, easyIdx(loc, rewriter, 0)); - - // if a target is provided, crop slice to given target - if (hasTarget) { - for (auto i = 0u; i < rank; ++i) { - // remember the target offset as we need to "shift back" for the local - // offset - shift[i] = easyIdx(loc, rewriter, tOffs[i]); - slcOffs[i] = (easyIdx(loc, rewriter, slcOffs[i]) + - (shift[i] * easyIdx(loc, rewriter, slcStrides[i]))) - .get(); - slcSizes[i] = tSizes[i]; - } - } - - for (auto lPart : srcParts) { - ::imex::ValVec lSlcOffsets; - auto pShape = createShapeOf(loc, rewriter, lPart); - // Compute local part - auto pOverlap = createOverlap(loc, rewriter, lOffs, pShape, slcOffs, - slcSizes, slcStrides); - auto pOffsets = std::get<0>(pOverlap); - auto pSizes_ = std::get<1>(pOverlap); - - if (lVOffsets.size() == 0) { - lVOffsets = std::get<2>(pOverlap); - // "shift back" the cropped target offset - for (auto i = 0u; i < rank; ++i) { - lVOffsets[i] = - (easyIdx(loc, rewriter, lVOffsets[i]) + shift[i]).get(); - } - } - - // get static size==1 and strides back - ::mlir::SmallVector<::mlir::OpFoldResult> pOffs, pStrides, pSizes; - for (size_t i = 0; i < rank; ++i) { - auto pOff_ = easyIdx(loc, rewriter, pOffsets[i]); - auto lOff_ = easyIdx(loc, rewriter, lOffs[i]); - auto lShp_ = easyIdx(loc, rewriter, pShape[i]); - auto lOff = (pOff_ - lOff_).min(lShp_); - pOffs.emplace_back(lOff.get()); - auto s = sSlcStrides[i]; - pStrides.emplace_back( - ::mlir::ShapedType::isDynamic(s) - ? ::mlir::OpFoldResult{slcStrides[i]} - : ::mlir::OpFoldResult{rewriter.getIndexAttr(s)}); - // this might break broadcasting since size=1 is no longer static - pSizes.emplace_back(pSizes_[i]); - } - - // create local view - lViews.emplace_back(rewriter.create<::imex::ndarray::SubviewOp>( - loc, - mlir::dyn_cast<::mlir::RankedTensorType>(mlir::dyn_cast<::imex::ndarray::NDArrayType>(lPart.getType()) - .cloneWithDynDims()), - lPart, pOffs, pSizes, pStrides)); - - // update local offset for next part - lOffs[0] = rewriter.createOrFold<::mlir::arith::AddIOp>(loc, lOffs[0], - pShape[0]); - } - - // init our new dist array - auto dEnv = getDistEnv(resDistType); - rewriter.replaceOp(op, createDistArray(loc, rewriter, dEnv.getTeam(), - slcSizes, lVOffsets, lViews)); - return ::mlir::success(); - } -}; - -/// Convert a global dist::InsertSliceOp to ndarray::InsertSliceOp on the local -/// data. Assumes that the input is properly partitioned: the target part or if -/// none provided the default partitioning. -struct InsertSliceOpConverter - : public ::mlir::OpConversionPattern<::imex::ndarray::InsertSliceOp> { - using ::mlir::OpConversionPattern< - ::imex::ndarray::InsertSliceOp>::OpConversionPattern; - - ::mlir::LogicalResult - matchAndRewrite(::imex::ndarray::InsertSliceOp op, - ::imex::ndarray::InsertSliceOp::Adaptor adaptor, - ::mlir::ConversionPatternRewriter &rewriter) const override { - - auto destArType = mlir::dyn_cast<::imex::ndarray::NDArrayType>( - op.getDestination().getType()); - auto srcArType = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(op.getSource().getType()); - if (!destArType || !isDist(destArType) || !srcArType || - !isDist(srcArType)) { - return ::mlir::failure(); - } - - auto loc = op.getLoc(); - auto dest = op.getDestination(); - auto slcOffs = ::mlir::getMixedValues(adaptor.getStaticOffsets(), - adaptor.getOffsets(), rewriter); - auto slcSizes = ::mlir::getMixedValues(adaptor.getStaticSizes(), - adaptor.getSizes(), rewriter); - auto slcStrides = ::mlir::getMixedValues(adaptor.getStaticStrides(), - adaptor.getStrides(), rewriter); - - auto srcRank = srcArType.getRank(); - auto srcParts = createPartsOf(loc, rewriter, op.getSource()); - unsigned ownPartIdx = srcParts.size() == 1 ? 0 : 1; - - // the destination is assumed to be contiguous always - auto destParts = createPartsOf(loc, rewriter, dest); - ::imex::ValVec destOffs = createLocalOffsetsOf(loc, rewriter, dest); - ::mlir::Value lDest; - // get the local part - if (destParts.size() == 1) { - lDest = destParts[0]; - } else { - lDest = destParts[1]; - // if it's the second part, we need to update the offset - auto pSizes = createShapeOf(loc, rewriter, destParts[0]); - destOffs[0] = rewriter.createOrFold<::mlir::arith::AddIOp>( - loc, destOffs[0], pSizes[0]); - } - - // The parts in src are in order and together form a uniform view. - // The view must have the same shape as the local part of dest. - // We can just insert one part of src after the other. - // We only need to update the off into dest's local part. - - auto destSizes = createShapeOf(loc, rewriter, lDest); - auto destOverlap = createOverlap(loc, rewriter, destOffs, destSizes, - slcOffs, slcSizes, slcStrides); - auto lOffs = std::get<0>(destOverlap); - auto lSizes = std::get<1>(destOverlap); - auto lo0 = - easyIdx(loc, rewriter, lOffs[0]) - easyIdx(loc, rewriter, destOffs[0]); - lOffs[0] = lo0.get(); - - if (srcRank) { - for (auto srcPart : srcParts) { - auto ary = mlir::cast<::imex::ndarray::NDArrayType>(srcPart.getType()); - if (ary.hasZeroSize()) { - continue; - } - - // the shape of the src part is also used for Sizes in insert_slice - auto srcSizes = - createShapeOf<::mlir::SmallVector<::mlir::OpFoldResult>>( - loc, rewriter, srcPart); - - // and finally insert this view into lDest - rewriter.create<::imex::ndarray::InsertSliceOp>( - loc, lDest, srcPart, lOffs, srcSizes, slcStrides); - - // for the next src part we have to move the offset in our lDest - lo0 = lo0 + (easyIdx(loc, rewriter, srcSizes[0]) * - easyIdx(loc, rewriter, slcStrides[0])); - lOffs[0] = lo0.get(); - } - } else { - // src is a 0d array - auto sz = easyIdx(loc, rewriter, lSizes[0]); - auto zero = easyIdx(loc, rewriter, 0); - rewriter.create<::mlir::scf::IfOp>( - loc, sz.sgt(zero).get(), - [&](::mlir::OpBuilder &builder, ::mlir::Location loc) { - builder.create<::imex::ndarray::InsertSliceOp>( - loc, lDest, srcParts[ownPartIdx], lOffs, lSizes, slcStrides); - builder.create<::mlir::scf::YieldOp>(loc); - }); - } - - rewriter.eraseOp(op); - - return ::mlir::success(); - } -}; - -/// Convert a global ndarray::ReshapeOp on a distributed array -/// to ndarray::ReshapeOp on the local data. -/// If needed, adds a repartition op. -/// The local partition (e.g. a RankedTensor) is wrapped in a -/// non-distributed NDArray and re-applied to ReshapeOp. -/// op gets replaced with global distributed array -struct ReshapeOpConverter - : public ::mlir::OpConversionPattern<::imex::ndarray::ReshapeOp> { - using ::mlir::OpConversionPattern< - ::imex::ndarray::ReshapeOp>::OpConversionPattern; - - /// Initialize the pattern. - void initialize() { - /// Signal that this pattern safely handles recursive application. - setHasBoundedRewriteRecursion(); - } - - ::mlir::LogicalResult - matchAndRewrite(::imex::ndarray::ReshapeOp op, - ::imex::ndarray::ReshapeOp::Adaptor adaptor, - ::mlir::ConversionPatternRewriter &rewriter) const override { - - auto src = op.getSource(); - auto srcDistType = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(src.getType()); - auto retDistType = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(op.getResult().getType()); - if (!(srcDistType && isDist(srcDistType) && retDistType && - isDist(retDistType))) { - return ::mlir::failure(); - } - - auto loc = op.getLoc(); - auto elType = srcDistType.getElementType(); - auto dEnv = getDistEnv(srcDistType); - auto ngShape = adaptor.getShape(); - auto gShape = createGlobalShapeOf(loc, rewriter, src); - auto lParts = createPartsOf(loc, rewriter, src); - auto lArray = lParts.size() == 1 ? lParts[0] : lParts[1]; - auto lOffs = createLocalOffsetsOf(loc, rewriter, src); - - // Repartitioning is needed if any of the partitions' size is not a multiple - // of the new chunksize. - // For now we always copy. some initial check existed in rev 3a0b97825382b - - assert(adaptor.getCopy().value_or(1) != 0 || - (false && "Distributed reshape currently requires copying")); - - // FIXME: Check return type: Check that static sizes are the same as the - // default part sizes - - auto team = dEnv.getTeam(); - auto nPart = createDefaultPartition(loc, rewriter, team, ngShape); - auto nlOffs = nPart.getLOffsets(); - auto nlShape = nPart.getLShape(); - auto shp = getShapeFromValues(nlShape); - auto lRetType = ::imex::ndarray::NDArrayType::get( - shp, elType, getNonDistEnvs(retDistType)); - - // call the idt runtime - auto htype = ::imex::distruntime::AsyncHandleType::get(getContext()); - auto nlArray = rewriter.create<::imex::distruntime::CopyReshapeOp>( - loc, ::mlir::TypeRange{htype, lRetType}, team, lArray, gShape, lOffs, - ngShape, nlOffs, nlShape); - (void)rewriter.create<::imex::distruntime::WaitOp>(loc, - nlArray.getHandle()); - // finally init dist array - rewriter.replaceOp( - op, createDistArray(loc, rewriter, team, ngShape, nlOffs, - ::mlir::ValueRange{nlArray.getNlArray()})); - - return ::mlir::success(); - } -}; - -/// Convert a global dist::EWBinOp to ndarray::EWBinOp on the local data. -/// Assumes that the partitioning of the inputs are properly aligned. -struct EWBinOpConverter - : public ::mlir::OpConversionPattern<::imex::dist::EWBinOp> { - using ::mlir::OpConversionPattern<::imex::dist::EWBinOp>::OpConversionPattern; - - /// Initialize the pattern. - void initialize() { - /// Signal that this pattern safely handles recursive application. - setHasBoundedRewriteRecursion(); - } - - // for lhs and rhs we generate an if-cascade which yields the view - // of the part which overlaps the current loop-slice. with static - // array shapes/offsets canonicalizer should eliminate - // conditions - void getPart(::mlir::OpBuilder &builder, ::mlir::Location loc, int64_t rank, - const EasyIdx &zero, const ::imex::ValVec &unitStrides, - ::mlir::ValueRange parts, - const ::mlir::SmallVector<::imex::ValVec> &shapes, - ::imex::EasyIdx &slcOff, const ::imex::EasyIdx &slcSz, - const ::imex::EasyIdx &pOff, unsigned i, - ::mlir::Value &resView) const { - auto pSz = easyIdx(loc, builder, shapes[i][0]); - auto vOff = slcOff - pOff; - auto pEnd = pOff + pSz; - - auto doNext = [&](const ::imex::EasyIdx &poff, - unsigned j) -> ::mlir::Value { - if (j < parts.size()) { - this->getPart(builder, loc, rank, zero, unitStrides, parts, shapes, - slcOff, slcSz, poff, j, resView); - return resView; - } else { - // we should never get here; create a array with recognizable shape - builder.create<::mlir::cf::AssertOp>( - loc, createInt(loc, builder, 0, 1), - "could not determine overlap of loop bounds and view"); - auto arType = - mlir::cast<::imex::ndarray::NDArrayType>(resView.getType()); - static int dbg = 47110000; - auto x = createIndex(loc, builder, ++dbg); - return builder - .create<::mlir::UnrealizedConversionCastOp>( - loc, cloneAsDynNonDist(arType), x) - .getResult(0); - } - }; - - // create a nested if-else-block returning a view with given args if - // condition cond is met, and returning orig resView otherwise (accepted - // as reference!) - auto ary = mlir::cast<::imex::ndarray::NDArrayType>(parts[i].getType()); - if (!(ary.hasUnitSize() || ary.hasZeroSize())) { - auto overlaps = slcOff.sge(pOff).land(slcOff.slt(pEnd)); - resView = - builder - .create<::mlir::scf::IfOp>( - loc, overlaps.get(), - [&](::mlir::OpBuilder &builder, ::mlir::Location loc) { - ::imex::ValVec vOffs(rank, zero.get()); - vOffs[0] = vOff.get(); - auto vShape = shapes[i]; - vShape[0] = slcSz.get(); - auto view = builder.create<::imex::ndarray::ExtractSliceOp>( - loc, parts[i], vOffs, vShape, unitStrides); - builder.create<::mlir::scf::YieldOp>(loc, view.getResult()); - }, - [&](::mlir::OpBuilder &builder, ::mlir::Location loc) { - builder.create<::mlir::scf::YieldOp>(loc, - doNext(pEnd, i + 1)); - }) - .getResult(0); - } else { - resView = doNext(pEnd, i + 1); - } - ++i; - }; - - ::mlir::LogicalResult - matchAndRewrite(::imex::dist::EWBinOp op, - ::imex::dist::EWBinOp::Adaptor adaptor, - ::mlir::ConversionPatternRewriter &rewriter) const override { - - auto lhsDistType = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(op.getLhs().getType()); - auto rhsDistType = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(op.getRhs().getType()); - auto resDistType = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(op.getResult().getType()); - // return failure if wrong ops or not distributed - if (!(lhsDistType && isDist(lhsDistType) && rhsDistType && - isDist(rhsDistType) && resDistType && isDist(resDistType))) { - return ::mlir::failure(); - } - - auto loc = op.getLoc(); - auto rank = resDistType.getRank(); - auto lhsRank = lhsDistType.getRank(); - auto rhsRank = rhsDistType.getRank(); - auto lhs = op.getLhs(); - auto rhs = op.getRhs(); - auto resGShape = resDistType.getShape(); - auto lhsNeedsView = - !(lhsDistType.hasUnitSize() || lhsDistType.hasZeroSize()); - auto rhsNeedsView = - !(rhsDistType.hasUnitSize() || rhsDistType.hasZeroSize()); - auto resArType = cloneAsDynNonDist(resDistType); - // auto core = adaptor.getCore(); - - ::imex::ValVec lhsParts, rhsParts; - int lhsOwnIdx, rhsOwnIdx; - // create array of parts, skip 0-sized parts - { - auto tmp = createPartsOf(loc, rewriter, lhs); - lhsOwnIdx = tmp.size() == 1 ? 0 : 1; - for (int i = 0; i < (int)tmp.size(); ++i) { - if (mlir::cast<::imex::ndarray::NDArrayType>(tmp[i].getType()) - .hasZeroSize()) { - if (i <= lhsOwnIdx) - --lhsOwnIdx; - } else { - lhsParts.emplace_back(tmp[i]); - } - } - tmp = createPartsOf(loc, rewriter, rhs); - rhsOwnIdx = tmp.size() == 1 ? 0 : 1; - for (int i = 0; i < (int)tmp.size(); ++i) { - if (mlir::cast<::imex::ndarray::NDArrayType>(tmp[i].getType()) - .hasZeroSize()) { - if (i <= rhsOwnIdx) - --rhsOwnIdx; - } else { - rhsParts.emplace_back(tmp[i]); - } - } - } - - if (lhsRank == 0 && rhsRank == 0) { - assert(lhsOwnIdx >= 0 && rhsOwnIdx >= 0); - rewriter.replaceOpWithNewOp<::imex::ndarray::EWBinOp>( - op, resArType, adaptor.getOp(), lhsParts[lhsOwnIdx], - rhsParts[rhsOwnIdx]); - return ::mlir::success(); - } - - auto zero = easyIdx(loc, rewriter, 0); - - // get global shape, offsets and team - auto dEnv = getDistEnv(lhsDistType); - auto team = dEnv.getTeam(); - ::imex::ValVec lOffs = adaptor.getTargetOffsets(); - if (lOffs.size() == 0 && resArType.getRank()) { - if (resDistType.hasUnitSize()) { - lOffs = ::imex::ValVec(resDistType.getRank(), zero.get()); - } else { - ::imex::ValVec gShape; - for (auto d : resGShape) { - gShape.emplace_back(createIndex(loc, rewriter, d)); - } - auto defPart = createDefaultPartition(loc, rewriter, team, gShape); - lOffs = defPart.getLOffsets(); - } - } - - ::imex::ValVec lhsOffs = createLocalOffsetsOf(loc, rewriter, lhs); - ::imex::ValVec rhsOffs = createLocalOffsetsOf(loc, rewriter, rhs); - - ::mlir::SmallVector<::imex::ValVec> lhsShapes, rhsShapes; - ::mlir::SmallVector loopStarts(1, zero); - auto theEnd = zero; - - if (lhsRank) { // insert bounds of lhs - for (auto p : lhsParts) { - auto shp = createShapeOf(loc, rewriter, p); - if (shp.size() && !lhsDistType.hasUnitSize()) { - loopStarts.emplace_back(loopStarts.back() + - easyIdx(loc, rewriter, shp[0])); - } - lhsShapes.emplace_back(std::move(shp)); - } - theEnd = theEnd.max(loopStarts.back()); - } - - if (rhsRank) { // insert bounds of rhs - auto prev = zero; - for (auto p : rhsParts) { - auto shp = createShapeOf(loc, rewriter, p); - if (shp.size() && !rhsDistType.hasUnitSize()) { - loopStarts.emplace_back(prev + easyIdx(loc, rewriter, shp[0])); - } - prev = loopStarts.back(); - rhsShapes.emplace_back(std::move(shp)); - } - theEnd = theEnd.max(loopStarts.back()); - } - - auto coreOff = ::imex::EasyIdx(loc, rewriter, ::mlir::Value{}); - auto coreEnd = coreOff; - if (true) { // insert bounds of core loop if provided - auto cOffs = adaptor.getCoreOffsets(); - if (cOffs.size()) { - auto cStart = easyIdx(loc, rewriter, cOffs[0]); - coreOff = cStart.min(theEnd); - loopStarts.emplace_back(coreOff); - auto cEnd = cStart + easyIdx(loc, rewriter, adaptor.getCoreSizes()[0]); - coreEnd = cEnd.min(theEnd); - loopStarts.emplace_back(coreEnd); - } - } - - // sort loops by start index - // generate compare-and-swap operation reflecting a bubble-sort - // because the first half and second half of the list are already sorted - // it is sufficient to reduce the outer loops to N/2 iterations - auto N = std::max(rhsParts.size(), lhsParts.size()); - // if we have a core, we need 2 more iterations - if (adaptor.getCoreOffsets().size() > 0) { - N += 2; - } - for (unsigned i = 0; i < N; ++i) { - for (unsigned j = 1; j < loopStarts.size(); ++j) { - auto a = loopStarts[j]; - auto b = loopStarts[j - 1]; - loopStarts[j - 1] = a.min(b); - loopStarts[j] = a.max(b); - } - } - ::mlir::SmallVector> loops; - for (auto i = 1u; i < loopStarts.size(); ++i) { - if (loopStarts[i].get() != loopStarts[i - 1].get()) { - loops.emplace_back(std::make_pair(loopStarts[i - 1], loopStarts[i])); - } - } - - // FIXME broadcasting - ::imex::ValVec resShape; - for (unsigned i = 0; i < rank; ++i) { - auto d = resGShape[i]; - assert(d >= 0); - if (i) { - resShape.emplace_back(createIndex(loc, rewriter, d)); - } else { - resShape.emplace_back(loops.back().second.get()); - } - } - - auto res = rewriter.create<::imex::ndarray::CreateOp>( - loc, resShape, ::imex::ndarray::fromMLIR(resDistType.getElementType()), - nullptr, getNonDistEnvs(resDistType)); - ::mlir::Value updatedRes = res.getResult(); - - ::imex::ValVec resOffs(rank, zero.get()); - ::imex::ValVec unitStrides(rank, createIndex(loc, rewriter, 1)); - - // for each loop slice, determine overlap with lhs and rhs - // apply to ndarray::ewbinop and insert into result array - auto createLoop = [&](const std::pair &lp, - const ::imex::EasyVal &cond) { - auto slcOff = lp.first; - auto slcSz = lp.second - slcOff; - - auto ifRes = rewriter.create<::mlir::scf::IfOp>( - loc, cond.land(slcSz.sgt(zero)).get(), - [&](::mlir::OpBuilder &builder, ::mlir::Location loc) { - // auto lhsOff = easyIdx(loc, rewriter, lhsOffs[0]); - // auto rhsOff = easyIdx(loc, rewriter, rhsOffs[0]); - - auto getUnitPart = - [&builder, - &loc](const ::mlir::ValueRange &parts) -> ::mlir::Value { - auto one = easyIdx(loc, builder, 1); - auto arType = mlir::cast<::imex::ndarray::NDArrayType>( - parts.front().getType()); - auto rtyp = arType.cloneWith( - ::mlir::SmallVector(arType.getRank(), 1), - arType.getElementType()); - auto getUnitPartImpl = - [&builder, &loc, &one, &parts, - &rtyp](unsigned i, auto getUnitPart_) -> ::mlir::Value { - if (i == parts.size() - 1) { - return builder.create<::imex::ndarray::CastOp>(loc, rtyp, - parts.back()); - } - auto p = parts[i]; - auto dims = - builder.createOrFold<::imex::ndarray::DimOp>(loc, p, 0); - auto cond = easyIdx(loc, builder, dims).eq(one); - return builder - .create<::mlir::scf::IfOp>( - loc, cond.get(), - [&](::mlir::OpBuilder &builder, ::mlir::Location loc) { - auto res = builder.create<::imex::ndarray::CastOp>( - loc, rtyp, p); - builder.create<::mlir::scf::YieldOp>(loc, - res.getResult()); - }, - [&](::mlir::OpBuilder &builder, ::mlir::Location loc) { - builder.create<::mlir::scf::YieldOp>( - loc, getUnitPart_(i + 1, getUnitPart_)); - }) - .getResult(0); - }; - return getUnitPartImpl(0, getUnitPartImpl); - }; - - ::mlir::Value lhsView = lhsParts.back(); - if (lhsRank && lhsNeedsView) { - getPart(builder, loc, rank, zero, unitStrides, lhsParts, - lhsShapes, slcOff, slcSz, zero, 0, lhsView); - } else if (lhsDistType.hasUnitSize()) { - lhsView = getUnitPart(lhsParts); - } else { - assert(lhsOwnIdx >= 0); - lhsView = lhsParts[lhsOwnIdx]; - } - - ::mlir::Value rhsView = rhsParts.back(); - if (rhsRank && rhsNeedsView) { - getPart(builder, loc, rank, zero, unitStrides, rhsParts, - rhsShapes, slcOff, slcSz, zero, 0, rhsView); - } else if (rhsDistType.hasUnitSize()) { - - rhsView = getUnitPart(rhsParts); - } else { - assert(lhsOwnIdx >= 0); - rhsView = rhsParts[rhsOwnIdx]; - } - - // we can now apply the ewbinop - auto opRes = builder.create<::imex::ndarray::EWBinOp>( - loc, resArType, op.getOp(), lhsView, rhsView); - // and copy the result intop the result array - resOffs[0] = slcOff.get(); - resShape[0] = slcSz.get(); - auto resAfterInsert = - builder.create<::imex::ndarray::ImmutableInsertSliceOp>( - loc, updatedRes, opRes, resOffs, resShape, unitStrides); - builder.create<::mlir::scf::YieldOp>(loc, - resAfterInsert.getResult()); - }, - [&](::mlir::OpBuilder &builder, ::mlir::Location loc) { - builder.create<::mlir::scf::YieldOp>(loc, updatedRes); - }); - return ifRes.getResult(0); - }; - - // create core loop first - auto easyTrue = ::imex::EasyVal(loc, rewriter, true); - if (coreOff.get()) { - updatedRes = createLoop({coreOff, coreEnd}, easyTrue); - } - - // all other loops - for (auto l : loops) { - // only need this loop if not core loop - auto cond = coreOff.get() ? coreOff.ne(l.first) : easyTrue; - updatedRes = createLoop(l, cond); - } - - // and init our new dist array - rewriter.replaceOp(op, createDistArray(loc, rewriter, team, resGShape, - lOffs, {updatedRes})); - - return ::mlir::success(); - } -}; - -/// Convert a global dist::EWUnyOp to ndarray::EWUnyOp on the local data. -/// Assumes that the partitioning of the inputs are properly aligned. -struct EWUnyOpConverter - : public ::mlir::OpConversionPattern<::imex::dist::EWUnyOp> { - using ::mlir::OpConversionPattern<::imex::dist::EWUnyOp>::OpConversionPattern; - - /// Initialize the pattern. - void initialize() { - /// Signal that this pattern safely handles recursive application. - setHasBoundedRewriteRecursion(); - } - - ::mlir::LogicalResult - matchAndRewrite(::imex::dist::EWUnyOp op, - ::imex::dist::EWUnyOp::Adaptor adaptor, - ::mlir::ConversionPatternRewriter &rewriter) const override { - - auto src = op.getSrc(); - auto srcDistType = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(src.getType()); - auto resDistType = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(op.getResult().getType()); - // return failure if wrong ops or not distributed - if (!srcDistType || !isDist(srcDistType) || !resDistType || - !isDist(resDistType)) { - return ::mlir::failure(); - } - - auto loc = op.getLoc(); - auto resArType = cloneAsDynNonDist(resDistType); - auto lParts = createPartsOf(loc, rewriter, src); - auto lOffsets = createLocalOffsetsOf(loc, rewriter, src); - - ::imex::ValVec resParts; - // go through all parts and apply unyop - for (auto part : lParts) { - auto res = rewriter.create<::imex::ndarray::EWUnyOp>( - loc, resArType, adaptor.getOp(), part); - resParts.emplace_back(res.getResult()); - } - - // get global shape - auto gShape = resDistType.getShape(); - // and init our new dist array - rewriter.replaceOp(op, createDistArray(loc, rewriter, - getDistEnv(srcDistType).getTeam(), - gShape, lOffsets, resParts)); - - return ::mlir::success(); - } -}; - -/// Convert ndarray::CastElemTypeOp if operating on distributed arrays -struct CastElemTypeOpConverter - : public ::mlir::OpConversionPattern<::imex::ndarray::CastElemTypeOp> { - using ::mlir::OpConversionPattern< - ::imex::ndarray::CastElemTypeOp>::OpConversionPattern; - - ::mlir::LogicalResult - matchAndRewrite(::imex::ndarray::CastElemTypeOp op, - ::imex::ndarray::CastElemTypeOp::Adaptor adaptor, - ::mlir::ConversionPatternRewriter &rewriter) const override { - - auto src = op.getInput(); - auto srcDistType = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(src.getType()); - auto resDistType = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(op.getResult().getType()); - // return failure if wrong ops or not distributed - if (!srcDistType || !isDist(srcDistType) || !resDistType || - !isDist(resDistType)) { - return ::mlir::failure(); - } - - auto loc = op.getLoc(); - auto lParts = createPartsOf(loc, rewriter, src); - auto lOffsets = createLocalOffsetsOf(loc, rewriter, src); - - ::imex::ValVec resParts; - // go through all parts and apply cast - for (auto part : lParts) { - // infer result type: non-dist, same shape, modified elem type - auto partType = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(part.getType()); - auto resArType = cloneAsNonDist(partType).cloneWith( - std::nullopt, resDistType.getElementType()); - auto castOp = rewriter.create<::imex::ndarray::CastElemTypeOp>( - loc, resArType, part); - resParts.emplace_back(castOp.getResult()); - } - - // get global shape - auto gShape = resDistType.getShape(); - // and init our new dist array - rewriter.replaceOp(op, createDistArray(loc, rewriter, - getDistEnv(srcDistType).getTeam(), - gShape, lOffsets, resParts)); - - return ::mlir::success(); - } -}; - -/// Replace ::imex::dist::InitDistArrayOp with unrealized_conversion_cast -/// InitDistArrayOp is a dummy op used only for propagating dist infos -struct InitDistArrayOpConverter - : public ::mlir::OpConversionPattern<::imex::dist::InitDistArrayOp> { - using ::mlir::OpConversionPattern< - ::imex::dist::InitDistArrayOp>::OpConversionPattern; - - void initialize() { - /// Signal that this pattern safely handles recursive application. - setHasBoundedRewriteRecursion(); - } - - ::mlir::LogicalResult - matchAndRewrite(::imex::dist::InitDistArrayOp op, - ::imex::dist::InitDistArrayOp::Adaptor adaptor, - ::mlir::ConversionPatternRewriter &rewriter) const override { - auto distType = - mlir::cast<::imex::ndarray::NDArrayType>(op.getResult().getType()); - if (!distType) { - return ::mlir::failure(); - } - rewriter.replaceOpWithNewOp<::imex::dist::InitDistArrayOp>( - op, typeConverter->convertType(op.getType()), adaptor.getLOffset(), - adaptor.getParts()); - return ::mlir::success(); - } -}; - -/// Convert ::imex::dist::PartsOfOp into respective operand of defining -/// op. We assume the defining op is a InitDistArrayOp or it was converted by a -/// unrealized_conversion_cast. -struct PartsOfOpConverter - : public ::mlir::OpConversionPattern<::imex::dist::PartsOfOp> { - using ::mlir::OpConversionPattern< - ::imex::dist::PartsOfOp>::OpConversionPattern; - - ::mlir::LogicalResult - matchAndRewrite(::imex::dist::PartsOfOp op, - typename ::imex::dist::PartsOfOp::Adaptor adaptor, - ::mlir::ConversionPatternRewriter &rewriter) const override { - auto base = adaptor.getArray(); - auto distType = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(op.getArray().getType()); - if (!distType || !isDist(distType)) { - return ::mlir::failure(); - } - - auto defOp = base.getDefiningOp(); - while (defOp && defOp->getNumOperands() == 1 && - ::mlir::isa<::mlir::UnrealizedConversionCastOp>(defOp)) { - defOp = defOp->getOperand(0).getDefiningOp(); - } - if (defOp) { - if (auto initOp = - ::mlir::dyn_cast<::imex::dist::InitDistArrayOp>(defOp)) { - rewriter.replaceOp(op, initOp.getParts()); - return ::mlir::success(); - } - } - // not a InitDistArrayOp - return ::mlir::failure(); - } -}; - -/// Convert ::imex::dist::LocalOffsetsOfOp into respective operand of defining -/// op. We assume the defining op is a InitDistArrayOp or it was converted by a -/// unrealized_conversion_cast. -struct LocalOffsetsOfOpConverter - : public ::mlir::OpConversionPattern<::imex::dist::LocalOffsetsOfOp> { - using ::mlir::OpConversionPattern< - ::imex::dist::LocalOffsetsOfOp>::OpConversionPattern; - - ::mlir::LogicalResult - matchAndRewrite(::imex::dist::LocalOffsetsOfOp op, - typename ::imex::dist::LocalOffsetsOfOp::Adaptor adaptor, - ::mlir::ConversionPatternRewriter &rewriter) const override { - auto base = adaptor.getArray(); - auto distType = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(op.getArray().getType()); - if (!distType || !isDist(distType)) { - return ::mlir::failure(); - } - - // 0d array - if (op.getNumResults() == 0) { - assert(distType.getRank() == 0); - rewriter.eraseOp(op); - return ::mlir::success(); - } - - auto defOp = base.getDefiningOp(); - while (defOp && defOp->getNumOperands() == 1 && - ::mlir::isa<::mlir::UnrealizedConversionCastOp>(defOp)) { - defOp = defOp->getOperand(0).getDefiningOp(); - } - if (defOp) { - if (auto initOp = - ::mlir::dyn_cast<::imex::dist::InitDistArrayOp>(defOp)) { - rewriter.replaceOp(op, initOp.getLOffset()); - return ::mlir::success(); - } - } - // not a InitDistArrayOp - return ::mlir::failure(); - } -}; - -/// Lowering ::imex::dist::DefaultPartitionOp: Compute default partition -/// for a given shape and number of processes. -/// We currently assume evenly split data. -/// We back-fill partitions if partitions are uneven (increase last to first -/// partition in prank-order by one additional item) -struct DefaultPartitionOpConverter - : public ::mlir::OpConversionPattern<::imex::dist::DefaultPartitionOp> { - using ::mlir::OpConversionPattern< - ::imex::dist::DefaultPartitionOp>::OpConversionPattern; - - ::mlir::LogicalResult - matchAndRewrite(::imex::dist::DefaultPartitionOp op, - ::imex::dist::DefaultPartitionOp::Adaptor adaptor, - ::mlir::ConversionPatternRewriter &rewriter) const override { - // FIXME: non-even partitions, ndims - auto gShape = adaptor.getGShape(); - int64_t rank = static_cast(gShape.size()); - - if (rank == 0) { - rewriter.eraseOp(op); - return ::mlir::success(); - } - - auto loc = op.getLoc(); - auto sz = easyIdx(loc, rewriter, gShape.front()); - auto np = easyIdx(loc, rewriter, adaptor.getNumProcs()); - auto pr = easyIdx(loc, rewriter, adaptor.getPRank()); - auto one = easyIdx(loc, rewriter, 1); - auto zero = easyIdx(loc, rewriter, 0); - - // compute tile size and local size (which can be greater) - auto rem = sz % np; - auto tSz = sz / np; - auto lSz = tSz + (pr + rem).sge(np).select(one, zero); - auto lOff = (pr * tSz) + zero.max(rem - (np - pr)); - - // store in result range - ::imex::ValVec res(2 * rank, zero.get()); - res[0] = lOff.get(); - res[rank] = lSz.max(zero).get(); - for (int64_t i = 1; i < rank; ++i) { - res[rank + i] = gShape[i]; - } - - rewriter.replaceOp(op, res); - return ::mlir::success(); - } -}; - -// Compute the overlap of local data and global slice and return -// as target part (global offset/size relative to requested slice) -// Currently only dim0 is cut, hence offs/sizes of all other dims -// will be identical to the ones of the requested slice -// (e.g. same size and offset 0) -struct LocalTargetOfSliceOpConverter - : public ::mlir::OpConversionPattern<::imex::dist::LocalTargetOfSliceOp> { - using ::mlir::OpConversionPattern< - ::imex::dist::LocalTargetOfSliceOp>::OpConversionPattern; - - ::mlir::LogicalResult - matchAndRewrite(::imex::dist::LocalTargetOfSliceOp op, - ::imex::dist::LocalTargetOfSliceOp::Adaptor adaptor, - ::mlir::ConversionPatternRewriter &rewriter) const override { - auto distType = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(op.getArray().getType()); - if (!(distType && isDist(distType))) - return ::mlir::failure(); - - auto loc = op.getLoc(); - auto src = op.getArray(); - auto slcOffs = adaptor.getOffsets(); - auto slcSizes = adaptor.getSizes(); - auto slcStrides = adaptor.getStrides(); - int64_t rank = slcOffs.size(); - - // Get the local part of the global slice - auto lOffs = createLocalOffsetsOf(loc, rewriter, src); - auto lParts = createPartsOf(loc, rewriter, src); - ::imex::ValVec lShape = createShapeOf(loc, rewriter, lParts.front()); - for (unsigned p = 1; p < lParts.size(); ++p) { - auto pShape = createShapeOf(loc, rewriter, lParts[p]); - lShape[0] = rewriter.createOrFold<::mlir::arith::AddIOp>(loc, lShape[0], - pShape[0]); - } - - auto ovlp = createOverlap<::mlir::ValueRange, ::imex::ValVec>( - loc, rewriter, lOffs, lShape, slcOffs, slcSizes, slcStrides); - auto lOff = std::get<2>(ovlp); - auto lSzs = std::get<1>(ovlp); - - ::imex::ValVec results(rank * 2, createIndex(loc, rewriter, 0)); - results[0 * rank] = lOff[0]; - results[1 * rank] = lSzs[0]; - - for (auto i = 1; i < rank; ++i) { - results[1 * rank + i] = slcSizes[i]; - } - - rewriter.replaceOp(op, results); - return ::mlir::success(); - } -}; - -/// Convert ::imex::dist::LocalBoundingBoxOp -/// 1. Computes offset and sizes of the provided slice when mapped to provided -/// target. -/// 2. If a bounding box is provided, computes the bounding box for it and the -/// result of 1. -struct LocalBoundingBoxOpConverter - : public ::mlir::OpConversionPattern<::imex::dist::LocalBoundingBoxOp> { - using ::mlir::OpConversionPattern< - ::imex::dist::LocalBoundingBoxOp>::OpConversionPattern; - - ::mlir::LogicalResult - matchAndRewrite(::imex::dist::LocalBoundingBoxOp op, - ::imex::dist::LocalBoundingBoxOp::Adaptor adaptor, - ::mlir::ConversionPatternRewriter &rewriter) const override { - - auto loc = op.getLoc(); - auto inner = op.getInner(); - assert(!inner); - auto vOffs = op.getOffsets(); - auto vSizes = op.getSizes(); - auto vStrides = op.getStrides(); - auto tOffs = op.getTargetOffsets(); - auto tSizes = op.getTargetSizes(); - auto bbOffs = op.getBBOffsets(); - auto bbSizes = op.getBBSizes(); - bool hasBB = !bbOffs.empty(); - - auto rank = vOffs.size(); - - // min start index (among all views) for each dim followed by sizes - ::imex::ValVec oprnds(rank * 2); - auto one = easyIdx(loc, rewriter, 1); - auto zero = easyIdx(loc, rewriter, 0); - - // for each dim and view compute min offset and max end - // return min offset and size (assuming stride 1 for the bb) - for (size_t i = 0; i < rank; ++i) { - ::mlir::SmallVector doffs; - ::mlir::SmallVector dends; - auto tOff = easyIdx(loc, rewriter, tOffs[i]); - auto tSz = easyIdx(loc, rewriter, tSizes[i]); - auto vOff = easyIdx(loc, rewriter, vOffs[i]); - auto vSz = easyIdx(loc, rewriter, vSizes[i]); - auto vSt = easyIdx(loc, rewriter, vStrides[i]); - - auto ttOff = vOff + tOff * vSt; - auto ttEnd = ttOff + (tSz * vSt) - (vSt - one); - auto has_tSz = tSz.sgt(zero); // the target might have size 0 - - auto bbSz = hasBB ? easyIdx(loc, rewriter, bbSizes[i]) : zero; - auto has_bbSz = bbSz.sgt(zero); // BB can have size 0 if BB had tSz 0 - auto bbOff = - hasBB ? has_bbSz.select(easyIdx(loc, rewriter, bbOffs[i]), ttOff) - : ttOff; - auto bbEnd = bbOff + bbSz; - - auto vEnd = vOff + (vSz * vSt) - (vSt - one); // one after last element - - auto off = has_tSz.select(vOff.max(ttOff), bbOff).min(bbOff); - auto end = has_tSz.select(vEnd.min(ttEnd), vEnd).max(bbEnd); - auto sz = has_tSz.select(end - off, bbSz); - - oprnds[i] = off.get(); - oprnds[i + rank] = sz.get(); - } - - rewriter.replaceOp(op, oprnds); - return ::mlir::success(); - } -}; - -struct ExtendHaloForSliceOpConverter - : public ::mlir::OpConversionPattern<::imex::dist::ExtendHaloForSliceOp> { - using ::mlir::OpConversionPattern< - ::imex::dist::ExtendHaloForSliceOp>::OpConversionPattern; - - ::mlir::LogicalResult - matchAndRewrite(::imex::dist::ExtendHaloForSliceOp op, - ::imex::dist::ExtendHaloForSliceOp::Adaptor adaptor, - ::mlir::ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - ::mlir::SymbolTableCollection symbolTable; - auto meshOp = ::mlir::mesh::getMesh(op, symbolTable); - if (!meshOp) { - return ::mlir::failure(); - } - - // compute number of shards along split axes - // compute sharded dims extends (element count per sharded dim of base array) - ::mlir::SmallVector numShards, shardedDims; - auto baseShape = op.getStaticShape(); - for (auto dim = 0; dim<(int64_t)op.getSplitAxes().size(); ++dim) { - auto axes = op.getSplitAxes().getAxes()[dim]; - if(!axes.empty()) { - numShards.emplace_back(::mlir::mesh::collectiveProcessGroupSize(axes.asArrayRef(), meshOp)); - assert(!::mlir::ShapedType::isDynamic(numShards.back())); - shardedDims.emplace_back(dim); - } - } - - // init halo sizes either from input or to 0 - ::mlir::SmallVector<::imex::EasyI64> haloSizes; - auto zero = easyI64(loc, rewriter, 0); - auto one = easyI64(loc, rewriter, 1); - if (op.getHaloSizes().empty()) { - haloSizes.resize(numShards.size()*2, zero); - } else { - assert(op.getHaloSizes().size() == numShards.size()*2); - for (auto sz : op.getHaloSizes()) { - haloSizes.emplace_back(easyI64(loc, rewriter, sz)); - } - } - - // iterate split axes and compute lower/upper halo bounds for each dim - int64_t curr = 0; - auto targetDimsOffs = op.getShardedDimsOffsets(); - for (size_t dim=0; dim { - using ::mlir::OpConversionPattern< - ::imex::dist::LocalCoreOp>::OpConversionPattern; - - ::mlir::LogicalResult - matchAndRewrite(::imex::dist::LocalCoreOp op, - ::imex::dist::LocalCoreOp::Adaptor adaptor, - ::mlir::ConversionPatternRewriter &rewriter) const override { - auto src = op.getArray(); - - auto distType = mlir::dyn_cast<::imex::ndarray::NDArrayType>(src.getType()); - if (!distType || !isDist(distType)) - return ::mlir::failure(); - - auto rank = distType.getRank(); - if (rank == 0) { - rewriter.eraseOp(op); - return ::mlir::success(); - } - int64_t resRank = op.getResults().size() / 2; - assert(resRank == rank); - - // local part, its offsets and shape - auto loc = op.getLoc(); - auto lOffsets = createLocalOffsetsOf(loc, rewriter, src); - auto lParts = createPartsOf(loc, rewriter, src); - unsigned ownPartIdx = lParts.size() == 1 ? 0 : 1; - ::mlir::Value lData = lParts[ownPartIdx]; - ::imex::ValVec lSizes = createShapeOf(loc, rewriter, lData); - if (ownPartIdx) { - ::mlir::Value lhData = lParts[0]; - ::imex::ValVec lhSizes = createShapeOf(loc, rewriter, lhData); - lOffsets[0] = (easyIdx(loc, rewriter, lOffsets[0]) + - easyIdx(loc, rewriter, lhSizes[0])) - .get(); - } - - ::imex::ValVec oprnds(rank * 2); - - auto cOffs = op.getCoreOffsets(); - auto cSizes = op.getCoreSizes(); - auto tOffs = op.getTargetOffsets(); - auto tSizes = op.getTargetSizes(); - auto sOffs = op.getSliceOffsets(); - auto sSizes = op.getSliceSizes(); - auto sStrs = op.getSliceStrides(); - - auto overlap = createOverlap<::imex::ValVec>(loc, rewriter, lOffsets, - lSizes, sOffs, sSizes, sStrs); - auto oOffs = std::get<2>(overlap); - auto oSizes = std::get<1>(overlap); - auto zero = easyIdx(loc, rewriter, 0); - - // for each dim compute max offset and min end - for (auto i = 0; i < rank; ++i) { - auto oOff = easyIdx(loc, rewriter, oOffs[i]); - auto oSz = easyIdx(loc, rewriter, oSizes[i]); - auto tOff = easyIdx(loc, rewriter, tOffs[i]); - auto tSz = easyIdx(loc, rewriter, tSizes[i]); - auto cOff = cOffs.size() ? easyIdx(loc, rewriter, cOffs[i]) : zero; - auto cSz = cSizes.size() ? easyIdx(loc, rewriter, cSizes[i]) : tSz; - - auto shift = tOff - oOff; - // the updated core offset is max of old and current - auto rOff = cOff.max(zero - shift); - - // the local remainder starting at tOff - auto lRemain = oSz - shift; - // the local max loop sz is - auto lMax = lRemain - rOff; - // the target local max loop sz is - auto tMax = tSz - rOff; - - // the updated core size is the diff of updated core off and min end - auto rSz = (cOff + cSz - rOff).min(lMax.min(tMax)); - - oprnds[i] = rOff.get(); // cOff.max(off).get(); - oprnds[i + rank] = rSz.get(); - } - - rewriter.replaceOp(op, oprnds); - return ::mlir::success(); - } -}; - -/// Convert ::imex::dist::RePartitionOp -/// Creates a new array from the input array by re-partitioning it -/// according to the target part (or default). The repartitioning -/// itself happens in a library call. -struct RePartitionOpConverter - : public ::mlir::OpConversionPattern<::imex::dist::RePartitionOp> { - using ::mlir::OpConversionPattern< - ::imex::dist::RePartitionOp>::OpConversionPattern; - - ::mlir::LogicalResult - matchAndRewrite(::imex::dist::RePartitionOp op, - ::imex::dist::RePartitionOp::Adaptor adaptor, - ::mlir::ConversionPatternRewriter &rewriter) const override { - auto base = op.getArray(); - - auto distType = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(base.getType()); - if (!distType || !isDist(distType)) - return ::mlir::failure(); - - auto loc = op.getLoc(); - auto rank = distType.getRank(); - ::imex::ValVec bbOffs = op.getTargetOffsets(); - ::imex::ValVec bbSizes = op.getTargetSizes(); - - // Get required info from base - auto dEnv = getDistEnv(distType); - auto team = dEnv.getTeam(); - auto gShape = createGlobalShapeOf(loc, rewriter, base); - auto sGShape = distType.getShape(); - auto lOffsets = createLocalOffsetsOf(loc, rewriter, base); - auto lParts = createPartsOf(loc, rewriter, base); - // auto rank = gShape.size(); - - auto zero = easyIdx(loc, rewriter, 0); - auto one = easyIdx(loc, rewriter, 1); - - // default target partition is balanced - if (bbSizes.empty()) { - if (distType.hasUnitSize()) { - bbOffs = ::imex::ValVec(rank, zero.get()); - bbSizes = ::imex::ValVec(rank, one.get()); - } else { - auto lPart = createDefaultPartition(loc, rewriter, team, gShape); - bbOffs = lPart.getLOffsets(); - bbSizes = lPart.getLShape(); - } - } - - // which is the part that we own? - assert(lParts.size() == 1 || lParts.size() == 3 || - (false && "Number of local parts must be 1 or 3")); - unsigned ownPartIdx = lParts.size() == 1 ? 0 : 1; - - // Get offsets and shapes of parts - ::mlir::Value lData = lParts[ownPartIdx]; - ::imex::ValVec lSizes = createShapeOf(loc, rewriter, lData); - - // determine overlap of new local part, we split dim 0 only - auto bbOff = easyIdx(loc, rewriter, bbOffs[0]); - auto bbSize = easyIdx(loc, rewriter, bbSizes[0]); - auto oldOff = easyIdx(loc, rewriter, lOffsets[0]); - if (ownPartIdx) { - auto lHShape = createShapeOf(loc, rewriter, lParts[0]); - oldOff = oldOff + easyIdx(loc, rewriter, lHShape[0]); - } - auto oldSize = easyIdx(loc, rewriter, lSizes[0]); - auto tEnd = bbOff + bbSize; - auto oldEnd = oldOff + oldSize; - auto ownOff = oldOff.max(bbOff); - auto ownSize = (oldEnd.min(tEnd) - ownOff).max(zero); - - // compute left and right halo sizes, we split dim 0 only - // FIXME device - ::imex::ValVec lHSizes(bbSizes), rHSizes(bbSizes); - if (distType.hasUnitSize()) { - lHSizes[0] = - oldSize.eq(zero).land(oldOff.sgt(zero)).select(one, zero).get(); - rHSizes[0] = - oldSize.eq(zero).land(oldOff.sle(zero)).select(one, zero).get(); - } else { - lHSizes[0] = (ownOff.min(tEnd) - bbOff).get(); - rHSizes[0] = (tEnd - (ownOff + ownSize)).max(zero).get(); - } - - auto upHa = rewriter.create<::imex::distruntime::GetHaloOp>( - loc, lData, gShape, lOffsets, bbOffs, bbSizes, lHSizes, rHSizes, team); - - // create subview of local part - ::mlir::Value ownView = lData; - if (!distType.hasUnitSize()) { - ::imex::ValVec vSizes = bbSizes; - vSizes[0] = ownSize.get(); - ::imex::ValVec vOffs = bbOffs; - vOffs[0] = (ownOff - oldOff).get(); - ::imex::ValVec unitStrides(distType.getRank(), - createIndex(loc, rewriter, 1)); - ownView = rewriter.create<::imex::ndarray::SubviewOp>( - loc, lData, vOffs, vSizes, unitStrides); - } - - // generate call to wait for halos - // An optimizing pass might move this to the first use of a halo part - rewriter.create<::imex::distruntime::WaitOp>(loc, upHa.getHandle()); - - // init dist array - rewriter.replaceOp( - op, createDistArray(loc, rewriter, team, sGShape, bbOffs, - {upHa.getLHalo(), ownView, upHa.getRHalo()})); - - return ::mlir::success(); - } -}; - -/// Convert a global ndarray::PermuteDimsOp on a distributed array -/// to ndarray::PermuteDimsOp on the local data. -/// If needed, adds a repartition op. -/// The local partition (e.g. a RankedTensor) is wrapped in a -/// non-distributed NDArray and re-applied to PermuteDimsOp. -/// op gets replaced with global distributed array -struct PermuteDimsOpConverter - : public ::mlir::OpConversionPattern<::imex::ndarray::PermuteDimsOp> { - using ::mlir::OpConversionPattern< - ::imex::ndarray::PermuteDimsOp>::OpConversionPattern; - - /// Initialize the pattern. - void initialize() { - /// Signal that this pattern safely handles recursive application. - setHasBoundedRewriteRecursion(); - } - - ::mlir::LogicalResult - matchAndRewrite(::imex::ndarray::PermuteDimsOp op, - ::imex::ndarray::PermuteDimsOp::Adaptor adaptor, - ::mlir::ConversionPatternRewriter &rewriter) const override { - - auto src = op.getSource(); - auto dst = op.getResult(); - auto srcType = mlir::dyn_cast<::imex::ndarray::NDArrayType>(src.getType()); - auto dstType = mlir::dyn_cast<::imex::ndarray::NDArrayType>(dst.getType()); - if (!(srcType && isDist(srcType) && dstType && isDist(dstType))) { - return ::mlir::failure(); - } - - auto loc = op.getLoc(); - auto srcEnv = getDistEnv(srcType); - auto team = srcEnv.getTeam(); - auto elementType = srcType.getElementType(); - - auto srcGShape = createGlobalShapeOf(loc, rewriter, src); - auto srcLParts = createPartsOf(loc, rewriter, src); - auto srcLArray = srcLParts.size() == 1 ? srcLParts[0] : srcLParts[1]; - auto srcLOffsets = createLocalOffsetsOf(loc, rewriter, src); - - auto dstGShape = createGlobalShapeOf(loc, rewriter, dst); - auto dstLPart = createDefaultPartition(loc, rewriter, team, dstGShape); - auto dstLOffsets = dstLPart.getLOffsets(); - auto dstLShape = dstLPart.getLShape(); - auto dstLShapeIndex = getShapeFromValues(dstLShape); - auto dstLType = ::imex::ndarray::NDArrayType::get( - dstLShapeIndex, elementType, getNonDistEnvs(dstType)); - - // call the dist runtime - auto handleType = ::imex::distruntime::AsyncHandleType::get(getContext()); - auto distLArray = rewriter.create<::imex::distruntime::CopyPermuteOp>( - loc, ::mlir::TypeRange{handleType, dstLType}, team, srcLArray, - srcGShape, srcLOffsets, dstLOffsets, dstLShape, adaptor.getAxes()); - (void)rewriter.create<::imex::distruntime::WaitOp>(loc, - distLArray.getHandle()); - // finally init dist array - rewriter.replaceOp( - op, createDistArray(loc, rewriter, team, dstGShape, dstLOffsets, - ::mlir::ValueRange{distLArray.getNlArray()})); - - return ::mlir::success(); - } -}; - -// ******************************* -// ***** Pass infrastructure ***** -// ******************************* - -// Full Pass -struct ConvertDistToStandardPass - : public ::imex::impl::ConvertDistToStandardBase< - ConvertDistToStandardPass> { - ConvertDistToStandardPass() = default; - - void runOnOperation() override { - auto &ctxt = getContext(); - ::mlir::TypeConverter typeConverter; - - // Convert unknown types to itself - typeConverter.addConversion([](::mlir::Type type) { return type; }); - - // distributed array gets converted into its individual members - typeConverter.addConversion([&ctxt](::imex::ndarray::NDArrayType type) - -> std::optional<::mlir::Type> { - if (auto dEnv = getDistEnv(type)) { - ::mlir::SmallVector<::mlir::Type> types; - auto rank = type.getRank(); - if (rank) { - for (auto pttyp : getPartsTypes(type)) { - types.emplace_back(pttyp); // parts - } - auto mrTyp = ::mlir::MemRefType::get(::std::array{rank}, - ::mlir::IndexType::get(&ctxt)); - types.emplace_back(mrTyp); // loffs - } else { - auto pts = getPartsTypes(type); - types.emplace_back(pts[pts.size() == 1 ? 0 : 1]); - } - return ::mlir::TupleType::get(&ctxt, types); - } - return type; - }); - - auto materializeArray = - [&](::mlir::OpBuilder &builder, ::imex::ndarray::NDArrayType type, - ::mlir::ValueRange inputs, - ::mlir::Location loc) -> ::mlir::Value { - assert(inputs.size() == 1); - auto input = inputs[0]; - auto itype = input.getType(); - auto ary = mlir::dyn_cast<::imex::ndarray::NDArrayType>(itype); - if (type != itype && ary) { - if (isDist(ary)) { - assert(ary.getRank() == 0); - auto parts = createPartsOf(loc, builder, input); - assert(parts.size() == 1); - return parts[0]; - } else { - return builder.create<::imex::ndarray::CastOp>(loc, type, input) - .getResult(); - } - } - return builder - .create<::mlir::UnrealizedConversionCastOp>(loc, type, inputs) - .getResult(0); - }; - - typeConverter.addSourceMaterialization(materializeArray); - - // we need two passes because argument materialization goes after all the - // other conversions. The first part converts all dist stuff except - // InitDistArrayOp which should then have no use. In the second pass we - // erase all InitDistArrayOps - - ::mlir::ConversionTarget target(ctxt); - target.addIllegalDialect<::imex::dist::DistDialect>(); - target.addLegalDialect< - ::imex::distruntime::DistRuntimeDialect, ::mlir::func::FuncDialect, - ::mlir::linalg::LinalgDialect, ::mlir::arith::ArithDialect, - ::imex::ndarray::NDArrayDialect, ::mlir::tensor::TensorDialect, - ::mlir::memref::MemRefDialect, ::mlir::cf::ControlFlowDialect, - ::mlir::bufferization::BufferizationDialect, - ::imex::region::RegionDialect>(); - target.addLegalOp<::mlir::UnrealizedConversionCastOp>(); // FIXME - - // make sure function boundaries get converted - target.addDynamicallyLegalOp<::mlir::func::FuncOp>( - [&](::mlir::func::FuncOp op) { - return typeConverter.isSignatureLegal(op.getFunctionType()) && - typeConverter.isLegal(&op.getBody()); - }); - target.addDynamicallyLegalOp< - ::mlir::func::CallOp, ::imex::ndarray::ReshapeOp, - ::imex::ndarray::InsertSliceOp, ::imex::ndarray::EWBinOp, - ::imex::ndarray::EWUnyOp, ::imex::ndarray::LinSpaceOp, - ::imex::ndarray::CreateOp, ::imex::ndarray::CopyOp, - ::imex::ndarray::ReductionOp, ::imex::ndarray::ToTensorOp, - ::imex::ndarray::DeleteOp, ::imex::ndarray::CastElemTypeOp, - ::imex::region::EnvironmentRegionOp, - ::imex::region::EnvironmentRegionYieldOp, - ::imex::ndarray::PermuteDimsOp>( - [&](::mlir::Operation *op) { return typeConverter.isLegal(op); }); - target.addLegalOp<::imex::dist::InitDistArrayOp>(); - - // All the dist conversion patterns/rewriter - ::mlir::RewritePatternSet patterns(&ctxt); - // all these patterns are converted - patterns - .insert( - typeConverter, &ctxt); - mlir::scf::populateSCFStructuralTypeConversionsAndLegality( - typeConverter, patterns, target); - ::imex::populateRegionTypeConversionPatterns(patterns, typeConverter); - - // Let's go! - if (::mlir::failed(::mlir::applyPartialConversion(getOperation(), target, - ::std::move(patterns)))) { - signalPassFailure(); - } - - // now remove all InitDistArrayOps - getOperation()->walk( - [&](::imex::dist::InitDistArrayOp op) { op->erase(); }); - } -}; - -} // namespace -} // namespace dist - -/// Populate the given list with patterns that convert Dist to Standard -void populateDistToStandardConversionPatterns( - ::mlir::LLVMTypeConverter &converter, ::mlir::RewritePatternSet &patterns) { - assert(false); -} - -/// Create a pass that convert Dist to Standard -std::unique_ptr<::mlir::Pass> createConvertDistToStandardPass() { - return std::make_unique<::imex::dist::ConvertDistToStandardPass>(); -} - -} // namespace imex diff --git a/lib/Conversion/NDArrayToLinalg/NDArrayToLinalg.cpp b/lib/Conversion/NDArrayToLinalg/NDArrayToLinalg.cpp index 9c0fbbb04..13bef311c 100644 --- a/lib/Conversion/NDArrayToLinalg/NDArrayToLinalg.cpp +++ b/lib/Conversion/NDArrayToLinalg/NDArrayToLinalg.cpp @@ -25,26 +25,17 @@ //===----------------------------------------------------------------------===// #include -#include #include #include #include #include #include -#include -#include #include -#include -#include #include #include #include -#include -#include -#include #include -#include #include #include @@ -56,28 +47,29 @@ namespace imex { namespace imex { -/// @return type without a sign -static mlir::Type makeSignlessType(mlir::Type type) { - if (auto intType = mlir::dyn_cast(type)) { - if (!intType.isSignless()) - return mlir::IntegerType::get(intType.getContext(), intType.getWidth()); - } - return type; -} - -/// @return operand cast to signless type if needed, val if not -static mlir::Value doSignCast(mlir::OpBuilder &builder, mlir::Location &loc, - mlir::Value val) { - auto origType = val.getType(); - auto signlessType = makeSignlessType(origType); - if (signlessType != origType) { - val = - builder - .create<::mlir::UnrealizedConversionCastOp>(loc, signlessType, val) - .getResult(0); - } - return val; -} +// /// @return type without a sign +// static mlir::Type makeSignlessType(mlir::Type type) { +// if (auto intType = mlir::dyn_cast(type)) { +// if (!intType.isSignless()) +// return mlir::IntegerType::get(intType.getContext(), +// intType.getWidth()); +// } +// return type; +// } + +// /// @return operand cast to signless type if needed, val if not +// static mlir::Value doSignCast(mlir::OpBuilder &builder, mlir::Location &loc, +// mlir::Value val) { +// auto origType = val.getType(); +// auto signlessType = makeSignlessType(origType); +// if (signlessType != origType) { +// val = +// builder +// .create<::mlir::UnrealizedConversionCastOp>(loc, signlessType, +// val) .getResult(0); +// } +// return val; +// } /// Create a linalg generic op from given output, input and body template @@ -100,63 +92,90 @@ auto createParFor(mlir::Location &loc, mlir::OpBuilder &builder, uint64_t rank, namespace { -struct CastLowering - : public ::mlir::OpConversionPattern<::imex::ndarray::CastOp> { - using OpConversionPattern::OpConversionPattern; +/// Convert ndarray.copy and its return type to memref.alloc + memref.copy. +struct CopyLowering : public ::mlir::OpRewritePattern<::imex::ndarray::CopyOp> { + using OpRewritePattern::OpRewritePattern; ::mlir::LogicalResult - matchAndRewrite(::imex::ndarray::CastOp op, - ::imex::ndarray::CastOp::Adaptor adaptor, - ::mlir::ConversionPatternRewriter &rewriter) const override { - auto *converter = getTypeConverter(); - assert(converter && "Type converter is not set"); - auto src = adaptor.getSource(); - auto inTyp = - mlir::dyn_cast<::mlir::RankedTensorType>(adaptor.getSource().getType()); - auto outTyp = mlir::dyn_cast<::mlir::RankedTensorType>( - converter->convertType(op.getType())); - - if (!inTyp || !outTyp) { + matchAndRewrite(::imex::ndarray::CopyOp op, + ::mlir::PatternRewriter &rewriter) const override { + // check output type and get operands + auto srcArTyp = + mlir::dyn_cast<::mlir::RankedTensorType>(op.getSource().getType()); + auto retArTyp = mlir::dyn_cast<::mlir::RankedTensorType>(op.getType()); + if (!(srcArTyp && retArTyp)) { return ::mlir::failure(); } - if (outTyp == inTyp) { - rewriter.replaceOp(op, src); - } else { - rewriter.replaceOpWithNewOp<::mlir::tensor::CastOp>(op, outTyp, src); + auto loc = op.getLoc(); + auto src = op.getSource(); + auto rank = srcArTyp.getRank(); + ::imex::ValVec dynDims; + + // get dynamic shape + auto tTyp = mlir::cast<::mlir::TensorType>(src.getType()); + for (int64_t i = 0; i < rank; ++i) { + if (tTyp.isDynamicDim(i)) { + dynDims.emplace_back( + rewriter.createOrFold<::mlir::tensor::DimOp>(loc, src, i)); + } + } + // alloc memref + auto mrTyp = + ::mlir::MemRefType::get(tTyp.getShape(), tTyp.getElementType()); + auto mr = rewriter.create<::mlir::memref::AllocOp>( + loc, mrTyp, dynDims, rewriter.getI64IntegerAttr(8)); + // and copy if non-0 + if (!imex::ndarray::hasZeroSize(retArTyp.getShape())) { + auto srcMR = createToMemRef(loc, rewriter, src, getMemRefType(srcArTyp)); + // create a region with given env, add copy op within it + auto env = rewriter.getStringAttr("protect_copy_op"); + rewriter.create<::imex::region::EnvironmentRegionOp>( + loc, env, std::nullopt, std::nullopt, + [&srcMR, &mr](::mlir::OpBuilder &builder, ::mlir::Location loc) { + (void)builder.create<::mlir::memref::CopyOp>(loc, srcMR, mr); + (void)builder.create<::imex::region::EnvironmentRegionYieldOp>(loc); + }); } + // convert memref to tensor + auto res = rewriter.create<::mlir::bufferization::ToTensorOp>( + loc, retArTyp, mr, /*restrict=*/true, + /*writable=*/true); + rewriter.replaceOp(op, res); + return ::mlir::success(); } }; -/// Convert FromMemRefOp to bufferize.to_tensor -struct FromMemRefLowering - : public ::mlir::OpConversionPattern<::imex::ndarray::FromMemRefOp> { - using OpConversionPattern::OpConversionPattern; +/// Convert NDArray's ReshapeOp and its return type to Linalg/tensor. +/// Optionally creates a copy first. +struct ReshapeLowering + : public ::mlir::OpRewritePattern<::imex::ndarray::ReshapeOp> { + using OpRewritePattern::OpRewritePattern; ::mlir::LogicalResult - matchAndRewrite(::imex::ndarray::FromMemRefOp op, - ::imex::ndarray::FromMemRefOp::Adaptor adaptor, - ::mlir::ConversionPatternRewriter &rewriter) const override { - - rewriter.replaceOpWithNewOp<::mlir::bufferization::ToTensorOp>( - op, adaptor.getInput(), /*restrict=*/true); - - return ::mlir::success(); - } -}; + matchAndRewrite(::imex::ndarray::ReshapeOp op, + ::mlir::PatternRewriter &rewriter) const override { + // check output type and get operands + auto retArTyp = mlir::dyn_cast<::mlir::RankedTensorType>(op.getType()); + auto srcArTyp = + mlir::dyn_cast<::mlir::RankedTensorType>(op.getSource().getType()); + if (!(retArTyp && srcArTyp)) { + return ::mlir::failure(); + } -/// Lower to the input operand of the defining op. -struct ToTensorLowering - : public ::mlir::OpConversionPattern<::imex::ndarray::ToTensorOp> { - using OpConversionPattern::OpConversionPattern; + auto loc = op.getLoc(); + auto src = op.getSource(); + auto shape = op.getShape(); - ::mlir::LogicalResult - matchAndRewrite(::imex::ndarray::ToTensorOp op, - ::imex::ndarray::ToTensorOp::Adaptor adaptor, - ::mlir::ConversionPatternRewriter &rewriter) const override { + if (op.getCopy().value_or(false)) { + src = rewriter.create<::imex::ndarray::CopyOp>(loc, srcArTyp, + op.getSource()); + } - rewriter.replaceOp(op, adaptor.getInput()); + auto shapeT = rewriter.create<::mlir::tensor::FromElementsOp>(loc, shape); + rewriter.replaceOpWithNewOp<::mlir::tensor::ReshapeOp>(op, retArTyp, src, + shapeT); return ::mlir::success(); } @@ -165,15 +184,14 @@ struct ToTensorLowering /// Convert NDArray's subview to memref::subview. /// Adjusted from NTensor struct SubviewLowering - : public ::mlir::OpConversionPattern<::imex::ndarray::SubviewOp> { - using OpConversionPattern::OpConversionPattern; + : public ::mlir::OpRewritePattern<::imex::ndarray::SubviewOp> { + using OpRewritePattern::OpRewritePattern; ::mlir::LogicalResult matchAndRewrite(::imex::ndarray::SubviewOp op, - ::imex::ndarray::SubviewOp::Adaptor adaptor, - ::mlir::ConversionPatternRewriter &rewriter) const override { + ::mlir::PatternRewriter &rewriter) const override { - auto srcTnsr = adaptor.getSource(); + auto srcTnsr = op.getSource(); auto loc = op->getLoc(); // convert src array to memref @@ -187,12 +205,12 @@ struct SubviewLowering srcArType.getElementType()); auto srcMR = createToMemRef(loc, rewriter, srcTnsr, srcMRType); - auto offsets = ::mlir::getMixedValues(adaptor.getStaticOffsets(), - adaptor.getOffsets(), rewriter); - auto sizes = ::mlir::getMixedValues(adaptor.getStaticSizes(), - adaptor.getSizes(), rewriter); - auto strides = ::mlir::getMixedValues(adaptor.getStaticStrides(), - adaptor.getStrides(), rewriter); + auto offsets = ::mlir::getMixedValues(op.getStaticOffsets(), + op.getOffsets(), rewriter); + auto sizes = + ::mlir::getMixedValues(op.getStaticSizes(), op.getSizes(), rewriter); + auto strides = ::mlir::getMixedValues(op.getStaticStrides(), + op.getStrides(), rewriter); auto resMRType = mlir::cast<::mlir::MemRefType>( ::mlir::memref::SubViewOp::inferRankReducedResultType( @@ -213,23 +231,22 @@ struct SubviewLowering /// Convert NDArray's extract_slice to tensor.extract_slice. struct ExtractSliceLowering - : public ::mlir::OpConversionPattern<::imex::ndarray::ExtractSliceOp> { - using OpConversionPattern::OpConversionPattern; + : public ::mlir::OpRewritePattern<::imex::ndarray::ExtractSliceOp> { + using OpRewritePattern::OpRewritePattern; ::mlir::LogicalResult matchAndRewrite(::imex::ndarray::ExtractSliceOp op, - ::imex::ndarray::ExtractSliceOp::Adaptor adaptor, - ::mlir::ConversionPatternRewriter &rewriter) const override { + ::mlir::PatternRewriter &rewriter) const override { - auto srcTnsr = adaptor.getSource(); + auto srcTnsr = op.getSource(); auto loc = op->getLoc(); - auto offsets = ::mlir::getMixedValues(adaptor.getStaticOffsets(), - adaptor.getOffsets(), rewriter); - auto sizes = ::mlir::getMixedValues(adaptor.getStaticSizes(), - adaptor.getSizes(), rewriter); - auto strides = ::mlir::getMixedValues(adaptor.getStaticStrides(), - adaptor.getStrides(), rewriter); + auto offsets = ::mlir::getMixedValues(op.getStaticOffsets(), + op.getOffsets(), rewriter); + auto sizes = + ::mlir::getMixedValues(op.getStaticSizes(), op.getSizes(), rewriter); + auto strides = ::mlir::getMixedValues(op.getStaticStrides(), + op.getStrides(), rewriter); auto res = rewriter.create<::mlir::tensor::ExtractSliceOp>( loc, srcTnsr, offsets, sizes, strides); @@ -239,68 +256,19 @@ struct ExtractSliceLowering } }; -/// Convert NDArray's DimOp to tensor::DimOp. -struct DimOpLowering : public mlir::OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(imex::ndarray::DimOp op, - imex::ndarray::DimOp::Adaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - auto srcTnsr = adaptor.getSource(); - auto srcType = mlir::dyn_cast<::mlir::TensorType>(srcTnsr.getType()); - if (!srcType) - return mlir::failure(); - - rewriter.replaceOpWithNewOp<::mlir::tensor::DimOp>(op, srcTnsr, - adaptor.getIndex()); - return mlir::success(); - } -}; - -/// Convert NDArray's LoadOp to tensor::ExtractOp. -/// Adjusted from NTensor -struct LoadOpLowering - : public mlir::OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - mlir::LogicalResult - matchAndRewrite(imex::ndarray::LoadOp op, - imex::ndarray::LoadOp::Adaptor adaptor, - mlir::ConversionPatternRewriter &rewriter) const override { - auto srcArType = - mlir::cast(op.getArray().getType()); - auto srcTnsr = adaptor.getArray(); - if (!mlir::isa(srcTnsr.getType())) - return mlir::failure(); - - auto *converter = getTypeConverter(); - assert(converter && "Type converter is not set"); - auto dstType = converter->convertType(op.getType()); - if (!dstType || dstType != srcArType.getElementType()) - return mlir::failure(); - - rewriter.replaceOpWithNewOp(op, srcTnsr, - adaptor.getIndices()); - - return mlir::success(); - } -}; - /// Convert NDArray's insert_slice to memref struct InsertSliceLowering - : public ::mlir::OpConversionPattern<::imex::ndarray::InsertSliceOp> { - using OpConversionPattern::OpConversionPattern; + : public ::mlir::OpRewritePattern<::imex::ndarray::InsertSliceOp> { + using OpRewritePattern::OpRewritePattern; ::mlir::LogicalResult matchAndRewrite(::imex::ndarray::InsertSliceOp op, - ::imex::ndarray::InsertSliceOp::Adaptor adaptor, - ::mlir::ConversionPatternRewriter &rewriter) const override { + ::mlir::PatternRewriter &rewriter) const override { auto loc = op.getLoc(); // get operators - auto src = adaptor.getSource(); - auto dst = adaptor.getDestination(); + auto src = op.getSource(); + auto dst = op.getDestination(); auto srcTyp = mlir::dyn_cast<::mlir::ShapedType>(src.getType()); auto dstTyp = mlir::dyn_cast<::mlir::ShapedType>(dst.getType()); if (!dstTyp || !srcTyp) @@ -313,12 +281,12 @@ struct InsertSliceLowering mlir::Value srcMR = createToMemRef(loc, rewriter, src, srcMRTyp); auto dstMR = createToMemRef(loc, rewriter, dst, dstMRTyp); - auto slcOffs = ::mlir::getMixedValues(adaptor.getStaticOffsets(), - adaptor.getOffsets(), rewriter); - auto slcSizes = ::mlir::getMixedValues(adaptor.getStaticSizes(), - adaptor.getSizes(), rewriter); - auto slcStrides = ::mlir::getMixedValues(adaptor.getStaticStrides(), - adaptor.getStrides(), rewriter); + auto slcOffs = ::mlir::getMixedValues(op.getStaticOffsets(), + op.getOffsets(), rewriter); + auto slcSizes = + ::mlir::getMixedValues(op.getStaticSizes(), op.getSizes(), rewriter); + auto slcStrides = ::mlir::getMixedValues(op.getStaticStrides(), + op.getStrides(), rewriter); auto view = rewriter.create<::mlir::memref::SubViewOp>( loc, dstMR, slcOffs, slcSizes, slcStrides); @@ -352,26 +320,24 @@ struct InsertSliceLowering /// Convert immutable_insert_slice to tensor struct ImmutableInsertSliceLowering - : public ::mlir::OpConversionPattern< - ::imex::ndarray::ImmutableInsertSliceOp> { - using OpConversionPattern::OpConversionPattern; + : public ::mlir::OpRewritePattern<::imex::ndarray::ImmutableInsertSliceOp> { + using OpRewritePattern::OpRewritePattern; ::mlir::LogicalResult matchAndRewrite(::imex::ndarray::ImmutableInsertSliceOp op, - ::imex::ndarray::ImmutableInsertSliceOp::Adaptor adaptor, - ::mlir::ConversionPatternRewriter &rewriter) const override { + ::mlir::PatternRewriter &rewriter) const override { auto loc = op.getLoc(); // get operators - auto src = adaptor.getSource(); - auto dst = adaptor.getDestination(); + auto src = op.getSource(); + auto dst = op.getDestination(); - auto offsets = ::mlir::getMixedValues(adaptor.getStaticOffsets(), - adaptor.getOffsets(), rewriter); - auto sizes = ::mlir::getMixedValues(adaptor.getStaticSizes(), - adaptor.getSizes(), rewriter); - auto strides = ::mlir::getMixedValues(adaptor.getStaticStrides(), - adaptor.getStrides(), rewriter); + auto offsets = ::mlir::getMixedValues(op.getStaticOffsets(), + op.getOffsets(), rewriter); + auto sizes = + ::mlir::getMixedValues(op.getStaticSizes(), op.getSizes(), rewriter); + auto strides = ::mlir::getMixedValues(op.getStaticStrides(), + op.getStrides(), rewriter); auto slice = rewriter.create<::mlir::tensor::InsertSliceOp>( loc, src, dst, offsets, sizes, strides); @@ -384,20 +350,19 @@ struct ImmutableInsertSliceLowering /// Convert NDArray's linspace and its return type to Linalg/tensor. /// Also needs some arith and affine (for linalg::genericop). struct LinSpaceLowering - : public ::mlir::OpConversionPattern<::imex::ndarray::LinSpaceOp> { - using OpConversionPattern::OpConversionPattern; + : public ::mlir::OpRewritePattern<::imex::ndarray::LinSpaceOp> { + using OpRewritePattern::OpRewritePattern; ::mlir::LogicalResult matchAndRewrite(::imex::ndarray::LinSpaceOp op, - ::imex::ndarray::LinSpaceOp::Adaptor adaptor, - ::mlir::ConversionPatternRewriter &rewriter) const override { + ::mlir::PatternRewriter &rewriter) const override { auto loc = op.getLoc(); - auto start = adaptor.getStart(); - auto stop = adaptor.getStop(); - auto count = adaptor.getNum(); - bool endpoint = adaptor.getEndpoint(); - auto retArTyp = mlir::dyn_cast<::imex::ndarray::NDArrayType>(op.getType()); + auto start = op.getStart(); + auto stop = op.getStop(); + auto count = op.getNum(); + bool endpoint = op.getEndpoint(); + auto retArTyp = mlir::dyn_cast<::mlir::RankedTensorType>(op.getType()); auto rank = retArTyp.getRank(); auto elTyp = retArTyp.getElementType(); @@ -447,27 +412,25 @@ struct LinSpaceLowering /// Convert NDArray's createOp and its return type to Linalg/tensor. /// Also needs some arith and affine (for linalg::genericop). struct CreateLowering - : public ::mlir::OpConversionPattern<::imex::ndarray::CreateOp> { - using OpConversionPattern::OpConversionPattern; + : public ::mlir::OpRewritePattern<::imex::ndarray::CreateOp> { + using OpRewritePattern::OpRewritePattern; ::mlir::LogicalResult matchAndRewrite(::imex::ndarray::CreateOp op, - ::imex::ndarray::CreateOp::Adaptor adaptor, - ::mlir::ConversionPatternRewriter &rewriter) const override { + ::mlir::PatternRewriter &rewriter) const override { auto loc = op.getLoc(); // check output type and get operands - auto retArTyp = mlir::dyn_cast<::imex::ndarray::NDArrayType>(op.getType()); + auto retArTyp = mlir::dyn_cast<::mlir::RankedTensorType>(op.getType()); if (!retArTyp) return ::mlir::failure(); - auto value = adaptor.getValue(); + auto value = op.getValue(); // init tensor auto elTyp = retArTyp.getElementType(); - ::mlir::Value res = - createEmptyTensor(rewriter, loc, elTyp, adaptor.getShape()); + ::mlir::Value res = createEmptyTensor(rewriter, loc, elTyp, op.getShape()); - if (!retArTyp.hasZeroSize() && value) { + if (!ndarray::hasZeroSize(retArTyp.getShape()) && value) { res = createParFor( loc, rewriter, retArTyp.getRank(), res, ::mlir::ValueRange(), [&value](::mlir::OpBuilder &builder, ::mlir::Location loc, @@ -482,83 +445,24 @@ struct CreateLowering } }; -/// Convert ndarray.copy and its return type to memref.alloc + memref.copy. -struct CopyLowering - : public ::mlir::OpConversionPattern<::imex::ndarray::CopyOp> { - using OpConversionPattern::OpConversionPattern; - - ::mlir::LogicalResult - matchAndRewrite(::imex::ndarray::CopyOp op, - ::imex::ndarray::CopyOp::Adaptor adaptor, - ::mlir::ConversionPatternRewriter &rewriter) const override { - // check output type and get operands - auto srcArTyp = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(op.getSource().getType()); - auto retArTyp = mlir::dyn_cast<::imex::ndarray::NDArrayType>(op.getType()); - if (!(srcArTyp && retArTyp)) { - return ::mlir::failure(); - } - - auto loc = op.getLoc(); - auto src = adaptor.getSource(); - auto rank = srcArTyp.getRank(); - ::imex::ValVec dynDims; - - // get dynamic shape - auto tTyp = mlir::cast<::mlir::TensorType>(src.getType()); - for (int64_t i = 0; i < rank; ++i) { - if (tTyp.isDynamicDim(i)) { - dynDims.emplace_back( - rewriter.createOrFold<::mlir::tensor::DimOp>(loc, src, i)); - } - } - // alloc memref - auto mrTyp = - ::mlir::MemRefType::get(tTyp.getShape(), tTyp.getElementType()); - auto mr = rewriter.create<::mlir::memref::AllocOp>( - loc, mrTyp, dynDims, rewriter.getI64IntegerAttr(8)); - // and copy if non-0 - if (!retArTyp.hasZeroSize()) { - auto srcMR = - createToMemRef(loc, rewriter, src, srcArTyp.getMemRefType(src)); - // create a region with given env, add copy op within it - auto env = rewriter.getStringAttr("protect_copy_op"); - rewriter.create<::imex::region::EnvironmentRegionOp>( - loc, env, std::nullopt, std::nullopt, - [&srcMR, &mr](::mlir::OpBuilder &builder, ::mlir::Location loc) { - (void)builder.create<::mlir::memref::CopyOp>(loc, srcMR, mr); - (void)builder.create<::imex::region::EnvironmentRegionYieldOp>(loc); - }); - } - // convert memref to tensor - auto res = rewriter.create<::mlir::bufferization::ToTensorOp>( - loc, retArTyp.getTensorType(), mr, /*restrict=*/true, - /*writable=*/true); - rewriter.replaceOp(op, res); - - return ::mlir::success(); - } -}; - /// Convert ndarray.delete and its return type to memref.dealloc. struct DeleteLowering - : public ::mlir::OpConversionPattern<::imex::ndarray::DeleteOp> { - using OpConversionPattern::OpConversionPattern; + : public ::mlir::OpRewritePattern<::imex::ndarray::DeleteOp> { + using OpRewritePattern::OpRewritePattern; ::mlir::LogicalResult matchAndRewrite(::imex::ndarray::DeleteOp op, - ::imex::ndarray::DeleteOp::Adaptor adaptor, - ::mlir::ConversionPatternRewriter &rewriter) const override { + ::mlir::PatternRewriter &rewriter) const override { // check output type and get operands auto inpArType = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(op.getInput().getType()); + mlir::dyn_cast<::mlir::RankedTensorType>(op.getInput().getType()); if (!inpArType) { return ::mlir::failure(); } - auto inp = adaptor.getInput(); - auto inpMR = createToMemRef(op.getLoc(), rewriter, inp, - inpArType.getMemRefType(inp)); + auto inp = op.getInput(); + auto inpMR = + createToMemRef(op.getLoc(), rewriter, inp, getMemRefType(inpArType)); auto newOp = rewriter.replaceOpWithNewOp<::mlir::memref::DeallocOp>(op, inpMR); newOp->setAttrs(op->getAttrs()); @@ -569,18 +473,17 @@ struct DeleteLowering /// Convert ndarray.cast_elemtype to linalg.generic with cast struct CastElemTypeLowering - : public ::mlir::OpConversionPattern<::imex::ndarray::CastElemTypeOp> { - using OpConversionPattern::OpConversionPattern; + : public ::mlir::OpRewritePattern<::imex::ndarray::CastElemTypeOp> { + using OpRewritePattern::OpRewritePattern; ::mlir::LogicalResult matchAndRewrite(::imex::ndarray::CastElemTypeOp op, - ::imex::ndarray::CastElemTypeOp::Adaptor adaptor, - ::mlir::ConversionPatternRewriter &rewriter) const override { + ::mlir::PatternRewriter &rewriter) const override { auto loc = op.getLoc(); - auto src = adaptor.getInput(); + auto src = op.getInput(); auto srcArType = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(op.getInput().getType()); - auto dstArType = mlir::dyn_cast<::imex::ndarray::NDArrayType>(op.getType()); + mlir::dyn_cast<::mlir::RankedTensorType>(op.getInput().getType()); + auto dstArType = mlir::dyn_cast<::mlir::RankedTensorType>(op.getType()); if (!(srcArType && dstArType)) { return ::mlir::failure(); } @@ -589,23 +492,18 @@ struct CastElemTypeLowering assert(dstArType.getRank() == srcArType.getRank()); assert(dstArType.getShape() == srcArType.getShape()); - auto srcType = srcArType.getTensorType(); - auto dstType = dstArType.getTensorType(); - auto dstElType = dstType.getElementType(); - - auto rank = srcType.getRank(); + auto dstElType = dstArType.getElementType(); + auto rank = srcArType.getRank(); auto map = rewriter.getMultiDimIdentityMap(rank); ::mlir::SmallVector iterators( rank, ::mlir::utils::IteratorType::parallel); // identical types if (srcArType == dstArType) { - if (adaptor.getCopy().value_or(false)) { + if (op.getCopy().value_or(false)) { // emit a copy op - auto arSrc = op.getInput(); - auto copyOp = - rewriter.create<::imex::ndarray::CopyOp>(loc, dstArType, arSrc); - rewriter.replaceOp(op, copyOp.getResult()); + rewriter.replaceOpWithNewOp<::imex::ndarray::CopyOp>(op, dstArType, + src); return ::mlir::success(); } else { // eliminate cast op @@ -615,9 +513,9 @@ struct CastElemTypeLowering } } - auto dst = createEmptyTensor(rewriter, loc, dstType, src); + auto dst = createEmptyTensor(rewriter, loc, dstArType, src); auto cast = rewriter.create<::mlir::linalg::GenericOp>( - loc, dstType, src, dst, ::mlir::ArrayRef({map, map}), iterators, + loc, dstArType, src, dst, ::mlir::ArrayRef({map, map}), iterators, [dstElType](::mlir::OpBuilder &b, ::mlir::Location loc, ::mlir::ValueRange args) { auto val = createCast(loc, b, args[0], dstElType); @@ -629,622 +527,6 @@ struct CastElemTypeLowering } }; -/// Convert NDArray's ReshapeOp and its return type to Linalg/tensor. -/// Optionally creates a copy first. -struct ReshapeLowering - : public ::mlir::OpConversionPattern<::imex::ndarray::ReshapeOp> { - using OpConversionPattern::OpConversionPattern; - - ::mlir::LogicalResult - matchAndRewrite(::imex::ndarray::ReshapeOp op, - ::imex::ndarray::ReshapeOp::Adaptor adaptor, - ::mlir::ConversionPatternRewriter &rewriter) const override { - // check output type and get operands - auto retArTyp = mlir::dyn_cast<::imex::ndarray::NDArrayType>(op.getType()); - auto srcArTyp = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(op.getSource().getType()); - if (!(retArTyp && srcArTyp)) { - return ::mlir::failure(); - } - - auto loc = op.getLoc(); - auto src = adaptor.getSource(); - auto shape = adaptor.getShape(); - auto outTyp = retArTyp.getTensorType(); - - if (adaptor.getCopy().value_or(false)) { - auto arSrc = op.getSource(); - auto copyOp = - rewriter.create<::imex::ndarray::CopyOp>(loc, srcArTyp, arSrc); - auto toTensorOp = - rewriter.create<::imex::ndarray::ToTensorOp>(loc, copyOp.getResult()); - src = toTensorOp.getResult(); - } - - auto shapeT = rewriter.create<::mlir::tensor::FromElementsOp>(loc, shape); - rewriter.replaceOpWithNewOp<::mlir::tensor::ReshapeOp>(op, outTyp, src, - shapeT); - - return ::mlir::success(); - } -}; - -// function type for building body for linalg::generic -using BodyType = std::function; - -// any genericOp body needs to close with a yield -// we also add a cast op to "typ" if needed -template -static void yield(mlir::OpBuilder &builder, ::mlir::Location loc, - ::mlir::Type typ, T val) { - auto res = val; - if (typ != res.getType()) { - res = builder.create<::mlir::UnrealizedConversionCastOp>(loc, typ, res) - .getResult(0); - } - (void)builder.create(loc, res); -} - -/// Trivial binop builders have simple equivalents in Arith. -/// The Arith ops are accepted as template arguments, one for ints and one for -/// floats. Currently only integers and floats are supported. -/// Currently unsigned int ops are not supported. -template -static BodyType buildTrivialBinary(::mlir::Type typ) { - return [typ](mlir::OpBuilder &builder, ::mlir::Location loc, - ::mlir::ValueRange args) -> void { - auto lhs = createCast(loc, builder, args[0], typ); - auto rhs = createCast(loc, builder, args[1], typ); - if (typ.isIntOrIndex()) { - if constexpr (!std::is_same_v) { - yield(builder, loc, typ, - builder.create(loc, lhs, rhs).getResult()); - return; - } else - assert(0 && - "Found integer type but binary op not defined for integers"); - } else if (typ.isIntOrIndexOrFloat()) { - if constexpr (!std::is_same_v) { - yield(builder, loc, typ, - builder.create(loc, lhs, rhs).getResult()); - return; - } else - assert(0 && "Found float type but binary op not defined for floats"); - } else { - assert(0 && "Only integers and floats supported for binary ops"); - } - }; -} - -static BodyType buildNegative(::mlir::Type typ) { - return [typ](mlir::OpBuilder &builder, ::mlir::Location loc, - ::mlir::ValueRange args) -> void { - mlir::TypedAttr minus; - if (typ.isUnsignedInteger()) { - assert(0 && "Unsigned integers are not supported in negative op"); - } else if (typ.isIntOrIndex()) { - minus = builder.getIntegerAttr(typ, -1); - } else if (typ.isIntOrIndexOrFloat()) { - minus = builder.getFloatAttr(typ, -1); - } else { - assert(0 && "Only integers and floats are supported"); - } - // Emit a trivial multiply binop with a constant scalar -1 - auto scalar = builder.create<::mlir::arith::ConstantOp>(loc, typ, minus); - auto mulOp = - buildTrivialBinary(typ); - mulOp(builder, loc, {args[0], scalar}); - }; -} - -/// Trivial unary op builders have simple equivalents in Math. -/// The Math ops are accepted as template arguments, one for ints and one for -/// floats. Currently only integers and floats are supported. -/// Currently unsigned int ops are not supported. -template -static BodyType buildTrivialUnary(::mlir::Type typ) { - return [typ](mlir::OpBuilder &builder, ::mlir::Location loc, - ::mlir::ValueRange args) -> void { - auto srcTyp = args[0].getType(); - if (srcTyp.isIntOrIndex()) { - if constexpr (!std::is_same_v) { - auto src = doSignCast(builder, loc, args[0]); - yield(builder, loc, typ, builder.create(loc, src).getResult()); - return; - } else - assert(0 && - "Found integer type but binary op not defined for integers"); - } else if (srcTyp.isIntOrIndexOrFloat()) { - if constexpr (!std::is_same_v) { - yield(builder, loc, typ, builder.create(loc, args[0]).getResult()); - return; - } else - assert(0 && "Found float type but binary op not defined for floats"); - } else { - assert(0 && "Only integers and floats supported for binary ops"); - } - }; -} - -/// get a body builder for given binary operation and result type. -/// Accepts a result type to insert a cast after the operation if needed -/// FIXME: add missing ops -static BodyType getBodyBuilder(::imex::ndarray::EWBinOpId binOp, - ::mlir::Type typ) { - switch (binOp) { - case ndarray::ADD: - return buildTrivialBinary(typ); - case ndarray::ATAN2: - return buildTrivialBinary(typ); - case ndarray::FLOOR_DIVIDE: - return buildTrivialBinary(typ); - // case ndarray::LOGADDEXP] = - // case ndarray::MATMUL] = - case ndarray::MAXIMUM: - return buildTrivialBinary( - typ); - case ndarray::MINIMUM: - return buildTrivialBinary( - typ); - case ndarray::MODULO: - return buildTrivialBinary(typ); - case ndarray::MULTIPLY: - return buildTrivialBinary(typ); - case ndarray::POWER: - return buildTrivialBinary(typ); - case ndarray::SUBTRACT: - return buildTrivialBinary(typ); - case ndarray::TRUE_DIVIDE: - return buildTrivialBinary<::mlir::arith::DivSIOp, ::mlir::arith::DivFOp>( - typ); - // case ndarray::BITWISE_LEFT_SHIFT] = - // case ndarray::BITWISE_RIGHT_SHIFT] = - - // case ndarray::EQUAL] = - // case ndarray::GREATER] = - // case ndarray::GREATER_EQUAL] = - // case ndarray::LESS] = - // case ndarray::LESS_EQUAL] = - // case ndarray::NOT_EQUAL] = - default: - assert(0 && "unsupported elementwise binary operation"); - }; -} - -::mlir::Value createTosaOp(::mlir::Location loc, - ::imex::ndarray::EWBinOpId binOpId, - ::mlir::ConversionPatternRewriter &rewriter, - ::mlir::TensorType returnType, ::mlir::Value lhs, - ::mlir::Value rhs) { - switch (binOpId) { - case ndarray::BITWISE_AND: - return rewriter - .create<::mlir::tosa::BitwiseAndOp>(loc, returnType, lhs, rhs) - .getResult(); - case ndarray::BITWISE_OR: - return rewriter.create<::mlir::tosa::BitwiseOrOp>(loc, returnType, lhs, rhs) - .getResult(); - case ndarray::BITWISE_XOR: - return rewriter - .create<::mlir::tosa::BitwiseXorOp>(loc, returnType, lhs, rhs) - .getResult(); - case ndarray::LOGICAL_AND: - return rewriter - .create<::mlir::tosa::LogicalAndOp>(loc, returnType, lhs, rhs) - .getResult(); - case ndarray::LOGICAL_OR: - return rewriter.create<::mlir::tosa::LogicalOrOp>(loc, returnType, lhs, rhs) - .getResult(); - case ndarray::LOGICAL_XOR: - return rewriter - .create<::mlir::tosa::LogicalXorOp>(loc, returnType, lhs, rhs) - .getResult(); - default: - break; - }; - return ::mlir::Value(); -} - -/// Convert NDArray's elementwise binary operations and their return type to -/// Linalg/tensor. The given op's type is expected to convert to the appropriate -/// type (shape and element-type). -/// Also needs some arith and affine (for linalg::genericop). -struct EWBinOpLowering - : public ::mlir::OpConversionPattern<::imex::ndarray::EWBinOp> { - using OpConversionPattern::OpConversionPattern; - - ::mlir::LogicalResult - matchAndRewrite(::imex::ndarray::EWBinOp op, - ::imex::ndarray::EWBinOp::Adaptor adaptor, - ::mlir::ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - // We expect to lower NDArrays - auto lhsArTyp = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(op.getLhs().getType()); - auto rhsArTyp = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(op.getRhs().getType()); - if (!lhsArTyp || !rhsArTyp) { - return ::mlir::failure(); - } - - auto resType = - mlir::cast<::imex::ndarray::NDArrayType>(op->getResult(0).getType()) - .getTensorType(); - // we assume the result type has been correctly promoted - auto elTyp = resType.getElementType(); - - // get the input as tensors - auto lhs = adaptor.getLhs(); - auto rhs = adaptor.getRhs(); - auto lhsTnsr = mlir::cast<::mlir::TensorType>(lhs.getType()); - auto rhsTnsr = mlir::cast<::mlir::TensorType>(rhs.getType()); - - // we expect tensor types on operands - auto lhsRank = lhsTnsr.getRank(); - auto rhsRank = rhsTnsr.getRank(); - - auto rank = static_cast(std::max(lhsRank, rhsRank)); - - const ::imex::ndarray::EWBinOpId binOpId = - (::imex::ndarray::EWBinOpId)mlir::cast<::mlir::IntegerAttr>( - adaptor.getOp()) - .getInt(); - - ::mlir::Value newOp = - createTosaOp(loc, binOpId, rewriter, resType, lhs, rhs); - if (!newOp) { - // generate linalg.generic loop - - // create output tensor with right dimensions - auto tensor = createEmptyTensor(rewriter, loc, resType, {lhs, rhs}); - - // we need affine maps for linalg::generic - // as long as we have no proper support for rank-reduced sizes above - // Linalg, we can handle only - // - explicitly rank-reduced inputs (such as explicit 0d tensors) - // - shapes with static dim-sizes of 1 - ::mlir::SmallVector<::mlir::AffineExpr> lhsExprs, rhsExprs, resExprs; - for (int i = 0; i < lhsRank; ++i) { - lhsExprs.emplace_back(lhsTnsr.getDimSize(i) == 1 - ? rewriter.getAffineConstantExpr(0) - : rewriter.getAffineDimExpr(i)); - } - for (int i = 0; i < rhsRank; ++i) { - rhsExprs.emplace_back(rhsTnsr.getDimSize(i) == 1 - ? rewriter.getAffineConstantExpr(0) - : rewriter.getAffineDimExpr(i)); - } - for (unsigned i = 0; i < rank; ++i) { - resExprs.emplace_back(rewriter.getAffineDimExpr(i)); - } - auto lhsMap = ::mlir::AffineMap::get(resType.getRank(), /*symbolCount=*/0, - lhsExprs, rewriter.getContext()); - auto rhsMap = ::mlir::AffineMap::get(resType.getRank(), /*symbolCount=*/0, - rhsExprs, rewriter.getContext()); - auto resMap = rewriter.getMultiDimIdentityMap(resType.getRank()); - - // we just make all dims parallel - ::mlir::SmallVector iterators( - rank, ::mlir::utils::IteratorType::parallel); - - // get the body builder for our binop and create genericop - // FIXME: make createParFor ready for this - auto bodyBuilder = getBodyBuilder(binOpId, elTyp); - newOp = - rewriter - .create<::mlir::linalg::GenericOp>( - loc, tensor.getType(), ::mlir::ValueRange{lhs, rhs}, tensor, - ::mlir::ArrayRef<::mlir::AffineMap>{lhsMap, rhsMap, resMap}, - iterators, bodyBuilder) - .getResult(0); - } - rewriter.replaceOp(op, newOp); - - return ::mlir::success(); - } -}; - -/// get a body builder for given binary operation and result type. -/// Accepts a result type to insert a cast after the operation if needed -/// FIXME: add missing ops -static BodyType getBodyBuilder(::imex::ndarray::EWUnyOpId binOp, - ::mlir::Type typ) { - switch (binOp) { - case ndarray::ABS: - return buildTrivialUnary<::mlir::math::AbsIOp, ::mlir::math::AbsFOp>(typ); - case ndarray::ATAN: - return buildTrivialUnary(typ); - case ndarray::CEIL: - return buildTrivialUnary(typ); - case ndarray::COS: - return buildTrivialUnary(typ); - case ndarray::ERF: - return buildTrivialUnary(typ); - case ndarray::EXP: - return buildTrivialUnary(typ); - case ndarray::EXPM1: - return buildTrivialUnary(typ); - case ndarray::FLOOR: - return buildTrivialUnary(typ); - case ndarray::LOG: - return buildTrivialUnary(typ); - case ndarray::LOG1P: - return buildTrivialUnary(typ); - case ndarray::LOG2: - return buildTrivialUnary(typ); - case ndarray::LOG10: - return buildTrivialUnary(typ); - case ndarray::ROUND: - return buildTrivialUnary(typ); - case ndarray::SIN: - return buildTrivialUnary(typ); - case ndarray::SQRT: - return buildTrivialUnary(typ); - case ndarray::TAN: - return buildTrivialUnary(typ); - case ndarray::TANH: - return buildTrivialUnary(typ); - case ndarray::TRUNC: - return buildTrivialUnary(typ); - case ndarray::NEGATIVE: - return buildNegative(typ); - default: - assert(0 && "unsupported elementwise binary operation"); - }; -} - -/// Lower unary operations which are not natively provided in any of the MLIR -/// dialects. -/// @return resulting non-null value if the operation was lowered, null-value -/// otherwise -::mlir::Value createAggUnaryOp(::mlir::Location loc, - ::imex::ndarray::EWUnyOpId unyOpId, - ::mlir::ConversionPatternRewriter &rewriter, - ::imex::ndarray::NDArrayType returnType, - ::mlir::Value src) { - switch (unyOpId) { - case ndarray::SQUARE: - return rewriter - .create<::imex::ndarray::EWBinOp>( - loc, returnType, - getIntAttr(rewriter, ::imex::ndarray::MULTIPLY, 32), src, src) - .getResult(); - default: - break; - }; - return ::mlir::Value(); -} - -/// Lower unary operations which are provided only by TOSA (and not by math or -/// arith). -/// @return resulting non-null value if the operation was lowered, null-value -/// otherwise -::mlir::Value createUnaryTosaOp(::mlir::Location loc, - ::imex::ndarray::EWUnyOpId unyOpId, - ::mlir::ConversionPatternRewriter &rewriter, - ::mlir::TensorType returnType, - ::mlir::Value src) { - switch (unyOpId) { - case ndarray::LOGICAL_NOT: - return rewriter.create(loc, returnType, src) - .getResult(); - default: - break; - }; - return ::mlir::Value(); -} - -/// Convert NDArray's elementwise unary operations and their return type to -/// Linalg/tensor. The given op's type is expected to convert to the appropriate -/// type (shape and element-type). -/// Also needs some arith and affine (for linalg::genericop). -struct EWUnyOpLowering - : public ::mlir::OpConversionPattern<::imex::ndarray::EWUnyOp> { - using OpConversionPattern::OpConversionPattern; - - ::mlir::LogicalResult - matchAndRewrite(::imex::ndarray::EWUnyOp op, - ::imex::ndarray::EWUnyOp::Adaptor adaptor, - ::mlir::ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto arSrc = op.getSrc(); - // We expect to lower NDArrays - auto srcArTyp = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(arSrc.getType()); - if (!srcArTyp) { - // FIXME type casting - return ::mlir::failure(); - } - - const ::imex::ndarray::EWUnyOpId unyOpId = - (::imex::ndarray::EWUnyOpId)mlir::cast<::mlir::IntegerAttr>( - adaptor.getOp()) - .getInt(); - if (unyOpId == ::imex::ndarray::POSITIVE) { - // positive unary op is a no-op, remove it - rewriter.replaceAllUsesWith(op.getResult(), op.getSrc()); - rewriter.eraseOp(op); - return ::mlir::success(); - } - - auto resArType = - mlir::cast<::imex::ndarray::NDArrayType>(op->getResult(0).getType()); - - // generic lowering of non-MLIR-native ops - ::mlir::Value newOp = - createAggUnaryOp(loc, unyOpId, rewriter, resArType, arSrc); - - if (!newOp) { // not lowered yet - // get the input/output tensor types - auto src = adaptor.getSrc(); - auto srcTnsr = mlir::cast<::mlir::TensorType>(src.getType()); - auto resType = resArType.getTensorType(); - - // we expect tensor types on operands - auto elTyp = srcTnsr.getElementType(); - auto rank = srcTnsr.getRank(); - - // try to lower to TOSA - newOp = createUnaryTosaOp(loc, unyOpId, rewriter, resType, src); - - if (!newOp) { // still not lowered: generate linalg.generic loop - // create output tensor with right dimensions - auto tensor = createEmptyTensor(rewriter, loc, resType, {src}); - - // we need affine maps for linalg::generic - const ::mlir::AffineMap map = ::mlir::AffineMap::getMultiDimIdentityMap( - rank, rewriter.getContext()); - ::mlir::SmallVector<::mlir::AffineMap> maps(2, map); - // we just make all dims parallel - ::mlir::SmallVector iterators( - rank, ::mlir::utils::IteratorType::parallel); - - // get the body builder for our binop and create genericop - // FIXME: make createParFor ready for this - auto bodyBuilder = getBodyBuilder(unyOpId, elTyp); - newOp = rewriter - .create<::mlir::linalg::GenericOp>( - loc, tensor.getType(), ::mlir::ValueRange{src}, tensor, - maps, iterators, bodyBuilder) - .getResult(0); - } - } - - rewriter.replaceOp(op, newOp); - - return ::mlir::success(); - } -}; - -// get a body builder for given binary operation and result type -// we accept a result type to insert a cast after the operation if needed -static BodyType getBodyBuilder(::imex::ndarray::ReduceOpId redOp, - ::mlir::Type typ) { - switch (redOp) { - case ::imex::ndarray::PROD: - return getBodyBuilder(::imex::ndarray::MULTIPLY, typ); - case ::imex::ndarray::SUM: - return getBodyBuilder(::imex::ndarray::ADD, typ); - case ::imex::ndarray::MAX: - return getBodyBuilder(::imex::ndarray::MAXIMUM, typ); - case ::imex::ndarray::MIN: - return getBodyBuilder(::imex::ndarray::MINIMUM, typ); - case ::imex::ndarray::MEAN: - case ::imex::ndarray::STD: - case ::imex::ndarray::VAR: - default: - assert(0 && "unsupported reduction operation"); - }; -} - -/// Convert NDArray's reduction operations and their return type to -/// Linalg/tensor. The given op's type is expected to convert to the appropriate -/// type (shape and element-type). Also needs some arith and affine (for -/// linalg::genericop). -// FIXME reduction over a subset of dimensionsstruct ReductionOpLowering -struct ReductionOpLowering - : public ::mlir::OpConversionPattern<::imex::ndarray::ReductionOp> { - using OpConversionPattern::OpConversionPattern; - - ::mlir::LogicalResult - matchAndRewrite(::imex::ndarray::ReductionOp op, - ::imex::ndarray::ReductionOp::Adaptor adaptor, - ::mlir::ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - // We expect to lower NDArrays - auto inpArTyp = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(op.getInput().getType()); - if (!inpArTyp) { - // fail if not, will be retried if operands get converted elsewhere - return ::mlir::failure(); - } - - // we expect tensorType as operands - auto inpTnsr = adaptor.getInput(); - auto inpTnsrTyp = mlir::cast<::mlir::TensorType>(inpTnsr.getType()); - - // Get signless operands into vec - ::mlir::SmallVector oprnds = {inpTnsr}; - - // determine resulting element type from converted op-type - auto retArTyp = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(op.getResult().getType()); - assert(retArTyp); - auto retTyp = retArTyp.getTensorType(); - auto elTyp = retTyp.getElementType(); - auto sElTyp = makeSignlessType(elTyp); - - // build tensor using the resulting element type and shape - // FIXME support reduction dimensions - auto rank = static_cast(retTyp.getRank()); - assert(rank == 0); - auto zeroI = createIndex(loc, rewriter, 0); - ::imex::ValVec shapeVVec(rank, zeroI); - // create new tensor - auto zero = createInt(loc, rewriter, 0); - auto tensor = createEmptyTensor(rewriter, loc, sElTyp, shapeVVec); - auto tnsr = rewriter.create<::mlir::linalg::FillOp>(loc, zero, tensor); - - // rank/num-dims of input - auto inpRank = static_cast(inpTnsrTyp.getRank()); - // input maps are identity maps - auto inpMap = ::mlir::AffineMap::getMultiDimIdentityMap( - inpRank, rewriter.getContext()); - // output map is "*->()" - auto omap = ::mlir::AffineMap::get(inpRank, 0, rewriter.getContext()); - const ::mlir::AffineMap maps[] = {inpMap, omap}; - ::mlir::SmallVector iterators( - inpRank, mlir::utils::IteratorType::reduction); - - // create reduction op as linalg::generic - const ::imex::ndarray::ReduceOpId ropid = - (::imex::ndarray::ReduceOpId)mlir::cast<::mlir::IntegerAttr>( - adaptor.getOp()) - .getInt(); - auto bodyBuilder = getBodyBuilder(ropid, sElTyp); - auto resTnsr = rewriter.create<::mlir::linalg::GenericOp>( - loc, tnsr.getType(0), oprnds, tnsr.getResult(0), maps, iterators, - bodyBuilder); - rewriter.replaceOp(op, resTnsr.getResult(0)); - - return ::mlir::success(); - } -}; - -/// Convert NDArray's permute_dims operations and their return type to -/// Linalg/tensor. -struct PermuteDimsOpLowering - : public ::mlir::OpConversionPattern<::imex::ndarray::PermuteDimsOp> { - using OpConversionPattern::OpConversionPattern; - - ::mlir::LogicalResult - matchAndRewrite(::imex::ndarray::PermuteDimsOp op, - ::imex::ndarray::PermuteDimsOp::Adaptor adaptor, - ::mlir::ConversionPatternRewriter &rewriter) const override { - - auto loc = op->getLoc(); - auto srcTnsr = adaptor.getSource(); - - // convert src array to memref - auto srcArType = mlir::dyn_cast_or_null<::imex::ndarray::NDArrayType>( - op.getSource().getType()); - if (!srcArType) - return mlir::failure(); - auto srcMRType = srcArType.getMemRefType(srcTnsr); - auto srcMR = createToMemRef(loc, rewriter, srcTnsr, srcMRType, true); - - auto perm = ::mlir::AffineMapAttr::get(::mlir::AffineMap::getPermutationMap( - adaptor.getAxes(), rewriter.getContext())); - mlir::memref::TransposeOp transposeOp = - rewriter.create(loc, srcMR, perm); - - rewriter.replaceOp(op, transposeOp.getResult()); - - return ::mlir::success(); - } -}; - // ******************************* // ***** Pass infrastructure ***** // ******************************* @@ -1260,113 +542,112 @@ struct ConvertNDArrayToLinalgPass void runOnOperation() override { auto &ctxt = getContext(); - ::mlir::TypeConverter typeConverter; - // Convert unknown types to itself - auto convT2T = [](::mlir::Type type) { return type; }; - // Convert NDArrayType to (tensorType) - auto convNDArray2RankedTensor = - [](::imex::ndarray::NDArrayType type) -> std::optional<::mlir::Type> { - return type.getTensorType(); - }; - - typeConverter.addConversion(convT2T); - typeConverter.addConversion(convNDArray2RankedTensor); - - auto materializeCast = - [](::mlir::OpBuilder &builder, ::mlir::Type type, - ::mlir::ValueRange inputs, - ::mlir::Location loc) -> ::mlir::Value { - if (inputs.size() == 1) { - auto input = inputs[0]; - auto itype = input.getType(); - if (mlir::isa<::mlir::TensorType>(type) and - mlir::isa<::mlir::TensorType>(itype)) { - return builder.create<::mlir::tensor::CastOp>(loc, type, inputs) - .getResult(); - } - auto ttype = mlir::dyn_cast<::mlir::RankedTensorType>(itype); - if (ttype && mlir::isa<::mlir::MemRefType>(type)) { - return createToMemRef(loc, builder, input, type); - } - auto mrtype = mlir::dyn_cast<::mlir::MemRefType>(itype); - if (mrtype && mlir::isa<::mlir::RankedTensorType>(type)) { - return builder - .create<::mlir::bufferization::ToTensorOp>(loc, type, input, - /*restrict=*/true) - .getResult(); - } - } - return builder - .create<::mlir::UnrealizedConversionCastOp>(loc, type, inputs) - .getResult(0); - }; - typeConverter.addSourceMaterialization(materializeCast); - typeConverter.addTargetMaterialization(materializeCast); - - // At function boundaries we have actual memref semantics. - // We need to explicitly convert in/out arguments to memrefs. - // If we use tensors downstream passes will auto-convert to non-strided - // memrefs which will imply a copy (converting from strided to non-strided - // requires a copy) - // We simply use a separate type-converter and materializations - - ::mlir::TypeConverter typeConverter4Func; - // Convert NDArrayType to MemRefType - auto convNDArray2MemRef = - [](::imex::ndarray::NDArrayType type) -> std::optional<::mlir::Type> { - return type.getMemRefType(); - }; - - typeConverter4Func.addConversion(convT2T); - typeConverter4Func.addConversion(convNDArray2MemRef); - typeConverter4Func.addSourceMaterialization(materializeCast); - typeConverter4Func.addTargetMaterialization(materializeCast); + // ::mlir::TypeConverter typeConverter; + // // Convert unknown types to itself + // auto convT2T = [](::mlir::Type type) { return type; }; + // // Convert NDArrayType to (tensorType) + // auto convNDArray2RankedTensor = + // [](::imex::ndarray::NDArrayType type) -> std::optional<::mlir::Type> + // { + // return type.getTensorType(); + // }; + + // typeConverter.addConversion(convT2T); + // typeConverter.addConversion(convNDArray2RankedTensor); + + // auto materializeCast = + // [](::mlir::OpBuilder &builder, ::mlir::Type type, + // ::mlir::ValueRange inputs, + // ::mlir::Location loc) -> ::mlir::Value { + // if (inputs.size() == 1) { + // auto input = inputs[0]; + // auto itype = input.getType(); + // if (mlir::isa<::mlir::TensorType>(type) and + // mlir::isa<::mlir::TensorType>(itype)) { + // return builder.create<::mlir::tensor::CastOp>(loc, type, inputs) + // .getResult(); + // } + // auto ttype = mlir::dyn_cast<::mlir::RankedTensorType>(itype); + // if (ttype && mlir::isa<::mlir::MemRefType>(type)) { + // return createToMemRef(loc, builder, input, type); + // } + // auto mrtype = mlir::dyn_cast<::mlir::MemRefType>(itype); + // if (mrtype && mlir::isa<::mlir::RankedTensorType>(type)) { + // return builder + // .create<::mlir::bufferization::ToTensorOp>(loc, type, input, + // /*restrict=*/true) + // .getResult(); + // } + // } + // return builder + // .create<::mlir::UnrealizedConversionCastOp>(loc, type, inputs) + // .getResult(0); + // }; + // typeConverter.addSourceMaterialization(materializeCast); + // typeConverter.addTargetMaterialization(materializeCast); + + // // At function boundaries we have actual memref semantics. + // // We need to explicitly convert in/out arguments to memrefs. + // // If we use tensors downstream passes will auto-convert to non-strided + // // memrefs which will imply a copy (converting from strided to + // non-strided + // // requires a copy) + // // We simply use a separate type-converter and materializations + + // ::mlir::TypeConverter typeConverter4Func; + // // Convert NDArrayType to MemRefType + // auto convNDArray2MemRef = + // [](::imex::ndarray::NDArrayType type) -> std::optional<::mlir::Type> + // { + // return type.getMemRefType(); + // }; + + // typeConverter4Func.addConversion(convT2T); + // typeConverter4Func.addConversion(convNDArray2MemRef); + // typeConverter4Func.addSourceMaterialization(materializeCast); + // typeConverter4Func.addTargetMaterialization(materializeCast); ::mlir::ConversionTarget target(ctxt); // We convert all NDArray stuff... target.addIllegalDialect<::imex::ndarray::NDArrayDialect>(); // ...into Linalg, Affine, Tensor, Arith target.addLegalDialect< - ::mlir::linalg::LinalgDialect, ::mlir::affine::AffineDialect, - ::mlir::arith::ArithDialect, ::mlir::math::MathDialect, + ::mlir::linalg::LinalgDialect, ::mlir::arith::ArithDialect, ::mlir::memref::MemRefDialect, ::mlir::tensor::TensorDialect, - ::mlir::tosa::TosaDialect, ::mlir::shape::ShapeDialect, ::mlir::bufferization::BufferizationDialect, ::imex::region::RegionDialect>(); - target.addLegalOp<::mlir::UnrealizedConversionCastOp>(); // FIXME + // target.addLegalOp<::mlir::UnrealizedConversionCastOp>(); // FIXME - // make sure function boundaries use tensors (not NDArrays) - target.addDynamicallyLegalOp<::mlir::func::FuncOp>( - [&](::mlir::func::FuncOp op) { - return typeConverter4Func.isSignatureLegal(op.getFunctionType()) && - typeConverter4Func.isLegal(&op.getBody()); - }); - target.addDynamicallyLegalOp<::mlir::func::ReturnOp, mlir::func::CallOp>( - [&](mlir::Operation *op) { return typeConverter4Func.isLegal(op); }); + // // make sure function boundaries use tensors (not NDArrays) + // target.addDynamicallyLegalOp<::mlir::func::FuncOp>( + // [&](::mlir::func::FuncOp op) { + // return typeConverter4Func.isSignatureLegal(op.getFunctionType()) && + // typeConverter4Func.isLegal(&op.getBody()); + // }); + // target.addDynamicallyLegalOp<::mlir::func::ReturnOp, mlir::func::CallOp>( + // [&](mlir::Operation *op) { return typeConverter4Func.isLegal(op); }); - target.addDynamicallyLegalOp<::imex::region::EnvironmentRegionOp, - ::imex::region::EnvironmentRegionYieldOp>( - [&](mlir::Operation *op) { return typeConverter.isLegal(op); }); + // target.addDynamicallyLegalOp<::imex::region::EnvironmentRegionOp, + // ::imex::region::EnvironmentRegionYieldOp>( + // [&](mlir::Operation *op) { return typeConverter.isLegal(op); }); ::mlir::RewritePatternSet patterns(&ctxt); - patterns.insert(typeConverter, - &ctxt); - ::imex::populateRegionTypeConversionPatterns(patterns, typeConverter); - - // populate function boundaries using our special type converter - ::mlir::populateFunctionOpInterfaceTypeConversionPattern< - ::mlir::func::FuncOp>(patterns, typeConverter4Func); - ::mlir::populateReturnOpTypeConversionPattern(patterns, typeConverter4Func); - ::mlir::populateCallOpTypeConversionPattern(patterns, typeConverter4Func); - - ::mlir::scf::populateSCFStructuralTypeConversionsAndLegality( - typeConverter, patterns, target); + patterns + .insert(&ctxt); + // ::imex::populateRegionTypeConversionPatterns(patterns, typeConverter); + + // // populate function boundaries using our special type converter + // ::mlir::populateFunctionOpInterfaceTypeConversionPattern< + // ::mlir::func::FuncOp>(patterns, typeConverter4Func); + // ::mlir::populateReturnOpTypeConversionPattern(patterns, + // typeConverter4Func); + // ::mlir::populateCallOpTypeConversionPattern(patterns, + // typeConverter4Func); + + // ::mlir::scf::populateSCFStructuralTypeConversionsAndLegality( + // typeConverter, patterns, target); if (::mlir::failed(::mlir::applyPartialConversion(getOperation(), target, ::std::move(patterns)))) { diff --git a/lib/Dialect/CMakeLists.txt b/lib/Dialect/CMakeLists.txt index 8fd29c73f..6d91a7df7 100644 --- a/lib/Dialect/CMakeLists.txt +++ b/lib/Dialect/CMakeLists.txt @@ -1,4 +1,3 @@ -add_subdirectory(Dist) add_subdirectory(DistRuntime) add_subdirectory(NDArray) add_subdirectory(Region) diff --git a/lib/Dialect/Dist/CMakeLists.txt b/lib/Dialect/Dist/CMakeLists.txt deleted file mode 100644 index 9f57627c3..000000000 --- a/lib/Dialect/Dist/CMakeLists.txt +++ /dev/null @@ -1,2 +0,0 @@ -add_subdirectory(IR) -add_subdirectory(Transforms) diff --git a/lib/Dialect/Dist/IR/CMakeLists.txt b/lib/Dialect/Dist/IR/CMakeLists.txt deleted file mode 100644 index b0a0f8be8..000000000 --- a/lib/Dialect/Dist/IR/CMakeLists.txt +++ /dev/null @@ -1,12 +0,0 @@ -add_imex_dialect_library(IMEXDistDialect - DistOps.cpp - - ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/include/mlir/Dialect/Dist - - DEPENDS - MLIRDistOpsIncGen - - LINK_LIBS PUBLIC - MLIRIR -) diff --git a/lib/Dialect/Dist/IR/DistOps.cpp b/lib/Dialect/Dist/IR/DistOps.cpp deleted file mode 100644 index de9f654f7..000000000 --- a/lib/Dialect/Dist/IR/DistOps.cpp +++ /dev/null @@ -1,229 +0,0 @@ -//===- DistOps.cpp - Dist dialect ------------------------------*- C++ -*-===// -// -// Copyright 2023 Intel Corporation -// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -/// -/// \file -/// This file implements the Dist dialect and its basic operations. -/// -//===----------------------------------------------------------------------===// - -#include -#include -#include -#include -#include -#include -#include - -namespace imex { -namespace dist { - -void DistDialect::initialize() { - addTypes< -#define GET_TYPEDEF_LIST -#include - >(); - addOperations< -#define GET_OP_LIST -#include - >(); - addAttributes< -#define GET_ATTRDEF_LIST -#include - >(); -} - -} // namespace dist -} // namespace imex - -static mlir::LogicalResult -parseDistEnv(mlir::AsmParser &parser, ::mlir::Attribute &team, - llvm::SmallVector &lOffsets, - llvm::SmallVector> &lshapes) { - llvm::SmallVector> dimensions; - llvm::SmallVector dims; - llvm::SmallVector lOffs; - - std::string tmp; - if (parser.parseKeyword("team")) { - return mlir::failure(); - } - if (parser.parseEqual()) { - return mlir::failure(); - } - if (parser.parseAttribute(team)) { - return mlir::failure(); - } - - if (parser.parseOptionalKeyword("loffs")) { - dimensions.push_back({}); - } else { - if (parser.parseEqual()) { - return mlir::failure(); - } - if (parser.parseCommaSeparatedList([&]() { - int64_t v = ::mlir::ShapedType::kDynamic; - auto opr = parser.parseOptionalInteger(v); - if (!opr.has_value()) { - if (parser.parseQuestion()) { - return mlir::failure(); - } - } - lOffs.emplace_back(v); - return mlir::success(); - })) { - return mlir::failure(); - } - auto n = lOffs.size(); - - if (parser.parseKeyword("lparts")) { - return mlir::failure(); - } - if (parser.parseEqual()) { - return mlir::failure(); - } - auto prs = [&]() { - if (parser.parseDimensionList(dims, true, false) || dims.size() != n) { - return mlir::failure(); - } - dimensions.emplace_back(dims); - dims.clear(); - return mlir::success(); - }; - if (parser.parseCommaSeparatedList(prs) || - !(dimensions.size() == 1 || dimensions.size() == 3)) { - return mlir::failure(); - } - } - - lOffsets = std::move(lOffs); - lshapes = std::move(dimensions); - return mlir::success(); -} - -static void -printDistEnv(mlir::AsmPrinter &printer, ::mlir::Attribute team, - const llvm::ArrayRef lOffs, - const llvm::SmallVector> lshapes) { - if (team) { - printer << "team = " << team; - - auto n = lOffs.size(); - if (n) { - auto printEl = [&](int64_t v, char sep, bool last) { - if (v == ::mlir::ShapedType::kDynamic) { - printer << '?'; - } else { - printer << v; - } - if (!last) - printer << sep; - }; - - printer << " loffs = "; - for (size_t i = 0; i < n; ++i) { - printEl(lOffs[i], ',', i >= n - 1); - } - - n = lshapes.size(); - printer << " lparts = "; - for (size_t i = 0; i < n; ++i) { - auto shape = lshapes[i]; - for (size_t j = 0; j < shape.size(); ++j) { - printEl(shape[j], 'x', j >= shape.size() - 1); - } - if (i < n - 1) { - printer << ','; - } - } - } - } -} - -#include -#define GET_TYPEDEF_CLASSES -#include -#define GET_ATTRDEF_CLASSES -#include -#define GET_OP_CLASSES -#include - -namespace imex { -namespace dist { - -DistEnvAttr DistEnvAttr::get( - ::mlir::Attribute team, ::llvm::ArrayRef lOffsets, - ::mlir::SmallVector<::mlir::SmallVector> partsShapes) { - assert(partsShapes.size() == 3 || partsShapes.size() == 1); - assert(team); - return get(team.getContext(), team, lOffsets, partsShapes); -} - -DistEnvAttr DistEnvAttr::get(::mlir::Attribute team, int64_t rank) { - assert(team); - ::mlir::SmallVector<::mlir::SmallVector> partsShapes( - rank ? 3 : 1, - ::mlir::SmallVector(rank, ::mlir::ShapedType::kDynamic)); - ::mlir::SmallVector lOffsets(rank, ::mlir::ShapedType::kDynamic); - return get(team.getContext(), team, lOffsets, partsShapes); -} - -DistEnvAttr DistEnvAttr::cloneWithDynOffsAndDims() const { - return get(getTeam(), getLOffsets().size()); -} - -void InitDistArrayOp::build(::mlir::OpBuilder &odsBuilder, - ::mlir::OperationState &odsState, - ::mlir::Attribute team, - ::mlir::ArrayRef g_shape, - ::mlir::ValueRange l_offset, - ::mlir::ValueRange parts, - ::mlir::ArrayRef<::mlir::Attribute> environments, - ::mlir::ArrayRef s_Offs) { - assert(l_offset.size() == g_shape.size()); - auto elTyp = mlir::cast<::imex::ndarray::NDArrayType>(parts.front().getType()) - .getElementType(); - ::mlir::SmallVector<::mlir::SmallVector> shapes; - for (auto p : parts) { - assert(!isDist(p)); - shapes.emplace_back( - mlir::cast<::imex::ndarray::NDArrayType>(p.getType()).getShape()); - } - auto resShape = getShapeFromValues(l_offset); - ::mlir::ArrayRef lOffs = s_Offs.size() ? s_Offs : resShape; - ::mlir::SmallVector<::mlir::Attribute> nEnvs(environments); - nEnvs.emplace_back(::imex::dist::DistEnvAttr::get(team, lOffs, shapes)); - auto arType = ::imex::ndarray::NDArrayType::get(g_shape, elTyp, nEnvs); - build(odsBuilder, odsState, arType, l_offset, parts); -} - -void PartsOfOp::build(::mlir::OpBuilder &odsBuilder, - ::mlir::OperationState &odsState, ::mlir::Value ary) { - auto pTypes = - getPartsTypes(mlir::cast<::imex::ndarray::NDArrayType>(ary.getType())); - assert(pTypes.size() == 1 || pTypes.size() == 3 || - (false && "Number of local parts must be 1 or 3")); - build(odsBuilder, odsState, pTypes, ary); -} - -::mlir::LogicalResult PartsOfOp::verify() { - if (this->getNumResults() == 1 || (this->getNumResults() == 3)) { - return ::mlir::success(); - } - return ::mlir::failure(); -} - -::mlir::LogicalResult EWBinOp::verify() { - if (isDist(getResult()) && isDist(getLhs()) && isDist(getRhs())) { - return ::mlir::success(); - } - return ::mlir::failure(); -} - -} // namespace dist -} // namespace imex diff --git a/lib/Dialect/Dist/Transforms/CMakeLists.txt b/lib/Dialect/Dist/Transforms/CMakeLists.txt deleted file mode 100644 index 5363b8540..000000000 --- a/lib/Dialect/Dist/Transforms/CMakeLists.txt +++ /dev/null @@ -1,15 +0,0 @@ -add_imex_dialect_library(IMEXDistTransforms - DistCoalesce.cpp - DistInferElementwiseCores.cpp - - ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/include/imex/Dialect/Dist - - DEPENDS - IMEXDistPassIncGen - - LINK_LIBS PUBLIC - MLIRIR - MLIRPass - IMEXDistDialect -) diff --git a/lib/Dialect/Dist/Transforms/DistInferElementwiseCores.cpp b/lib/Dialect/Dist/Transforms/DistInferElementwiseCores.cpp deleted file mode 100644 index ae3309b7e..000000000 --- a/lib/Dialect/Dist/Transforms/DistInferElementwiseCores.cpp +++ /dev/null @@ -1,313 +0,0 @@ -//===- DistInferEWCores.cpp - DistInferEWCores Transform ------*- C++ -*-===// -// -// Copyright 2023 Intel Corporation -// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -/// -/// \file -/// This file implements inferring core loops for elementwise operations. -//===----------------------------------------------------------------------===// - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include - -namespace imex { -#define GEN_PASS_DEF_DISTINFEREWCORES -#include "imex/Dialect/Dist/Transforms/Passes.h.inc" -} // namespace imex - -namespace imex { -namespace dist { - -namespace { - -struct DistInferEWCoresPass - : public imex::impl::DistInferEWCoresBase { - - DistInferEWCoresPass() = default; - - static bool isEW(::mlir::Operation *op) { // FIXME use interface or such - return (::mlir::isa<::imex::dist::EWBinOp>(op) || - ::mlir::isa<::imex::dist::EWUnyOp>(op)); - }; - - static bool hasCore(::mlir::Operation *op) { - assert(isEW(op)); - return op->getNumOperands() > 2; // FIXME use interface or such - } - - std::tuple<::imex::ValVec, ::imex::ValVec, ::imex::ValVec> static getCore( - ::mlir::Operation *op) { - if (auto typedOp = ::mlir::dyn_cast<::imex::dist::EWBinOp>(op)) { - return {typedOp.getCoreOffsets(), typedOp.getCoreSizes(), - typedOp.getTargetOffsets()}; - } else if (auto typedOp = ::mlir::dyn_cast<::imex::dist::EWUnyOp>(op)) { - return {typedOp.getCoreOffsets(), typedOp.getCoreSizes(), - typedOp.getTargetOffsets()}; - } - assert("Expected ewop"); - return {}; - } - - // Adds local core to all dependent ewops of given ewop. - // Dependent ewops are the ewop itself, operands and users. - // Stops when visiting non-ewop. - // Adds visited ewops to visited. - // Adds ewops to alien if the ewop has a different core. - void propagateAddLocalCore( - ::mlir::IRRewriter &builder, ::mlir::Operation *op, - ::mlir::Operation *lcOp, ::imex::ValVec &coreOffs, - ::imex::ValVec &coreSzs, ::imex::ValVec &targetOffs, - ::std::set<::mlir::Operation *> &visited, - ::std::set<::mlir::Operation *, ::imex::opOrderCmp> &alien) { - auto &dom = this->getAnalysis<::mlir::DominanceInfo>(); - if (!dom.dominates(lcOp, op)) { - return; - } - - // add core to all ewops that we visited - if (hasCore(op)) { - auto core = getCore(op); - auto aCoreOffs = std::get<0>(core); - auto aCoreSzs = std::get<1>(core); - auto aTargetOffs = std::get<2>(core); - if (coreOffs != aCoreOffs || coreSzs != aCoreSzs || - targetOffs != aTargetOffs) { - alien.emplace(op); - } - } else { - op->insertOperands(op->getNumOperands(), coreOffs); - op->insertOperands(op->getNumOperands(), coreSzs); - op->insertOperands(op->getNumOperands(), targetOffs); - } - visited.emplace(op); - - // we need to back-propagate to operands - for (int i = 0; i < (mlir::isa(op) ? 2 : 1); ++i) { - auto oprnd = op->getOperand(i).getDefiningOp(); - if (oprnd && isEW(oprnd) && visited.find(oprnd) == visited.end()) { - propagateAddLocalCore(builder, oprnd, lcOp, coreOffs, coreSzs, - targetOffs, visited, alien); - } - } - - // forward to dependent uses - for (auto user : op->getUsers()) { - if (isEW(user) && visited.find(user) == visited.end()) { - propagateAddLocalCore(builder, user, lcOp, coreOffs, coreSzs, - targetOffs, visited, alien); - } - } - } - - // pull operation up to last defining op/producer - void pullOp(::mlir::Operation *op, ::mlir::Operation *barrier = nullptr) { - std::vector<::mlir::Operation *> deps; - for (auto o : op->getOperands()) { - auto defOp = o.getDefiningOp(); - if (defOp) { - deps.emplace_back(defOp); - } - } - if (barrier) { - deps.emplace_back(barrier); - } - assert(!deps.empty()); - auto &dom = this->getAnalysis<::mlir::DominanceInfo>(); - std::sort(deps.begin(), deps.end(), ::imex::opOrderCmp(dom)); - op->moveAfter(deps.back()); - } - - void runOnOperation() override { - auto root = this->getOperation(); - // first run canonicalizer - ::mlir::PassManager pm(&getContext()); - // Add the canonicalization pass. - pm.addPass(::mlir::createCanonicalizerPass()); - // Run the PassManager. - if (::mlir::failed(pm.run(root))) { - signalPassFailure(); - } - - ::mlir::IRRewriter builder(&getContext()); - ::mlir::SmallVector<::mlir::Operation *> rpOps, bbOps; - - // find all ewops - root->walk([&](::mlir::Operation *op) { - if (auto typedOp = ::mlir::dyn_cast<::imex::dist::RePartitionOp>(op)) { - auto arTyp = mlir::cast<::imex::ndarray::NDArrayType>( - typedOp.getResult().getType()); - if (!arTyp.hasUnitSize() && !arTyp.hasZeroSize()) { - rpOps.emplace_back(op); - } - } else if (::mlir::isa<::imex::dist::LocalTargetOfSliceOp, - ::imex::dist::LocalBoundingBoxOp>(op)) { - bbOps.emplace_back(op); - } - }); - - // pull up all LocalTargetOfSliceOps and BoundingBoxOps as much as possible. - for (auto bbOp : bbOps) { - pullOp(bbOp); - } - bbOps.clear(); - - // recursively traverse all ewops and insert localcoreop and update ewbinops - // accordingly - for (auto rpOp : rpOps) { - ::std::unordered_multimap<::mlir::Operation *, ::imex::dist::LocalCoreOp> - lcOps; - // find all subviews on the repartitioned array to add localcoreops - std::vector<::mlir::Operation *> users(rpOp->getUsers().begin(), - rpOp->getUsers().end()); - std::sort(users.begin(), users.end(), - opOrderCmp(this->getAnalysis<::mlir::DominanceInfo>())); - auto base = ::mlir::cast<::imex::dist::RePartitionOp>(rpOp).getArray(); - for (auto user : users) { - if (auto svOp = ::mlir::dyn_cast<::imex::dist::SubviewOp>(user)) { - auto arTyp = mlir::cast<::imex::ndarray::NDArrayType>( - svOp.getResult().getType()); - - if (!arTyp.hasUnitSize() && !arTyp.hasZeroSize()) { - ::imex::dist::LocalCoreOp nlcOp; - // check if view has a ewop as user - for (auto svUse : svOp->getUsers()) { - if (isEW(svUse)) { - auto loc = svOp->getLoc(); - auto lcOp = lcOps.find(svUse); - ::imex::ValVec offsets, sizes, strides; - // constants should go to the beginning of the block - builder.setInsertionPointToStart(svOp->getBlock()); - offsets = ::imex::getMixedAsValues( - loc, builder, svOp.getOffsets(), svOp.getStaticOffsets()); - sizes = ::imex::getMixedAsValues(loc, builder, svOp.getSizes(), - svOp.getStaticSizes()); - strides = ::imex::getMixedAsValues( - loc, builder, svOp.getStrides(), svOp.getStaticStrides()); - - auto resOffsets = - lcOp == lcOps.end() - ? ::mlir::ValueRange{} - : ::mlir::ValueRange(lcOp->second.getResultOffsets()); - auto resSizes = - lcOp == lcOps.end() - ? ::mlir::ValueRange{} - : ::mlir::ValueRange(lcOp->second.getResultSizes()); - // we start with localcoreop right after subviewop... - builder.setInsertionPointAfter(svOp); - nlcOp = builder.create<::imex::dist::LocalCoreOp>( - loc, base, svOp.getTargetOffsets(), svOp.getTargetSizes(), - offsets, sizes, strides, resOffsets, resSizes); - // ...and pull it up as much as possible - pullOp(nlcOp); - lcOps.emplace(svUse, nlcOp); - } - } - } - } - } - - // for each "initial" ewop add core to dependent ewops and update for - // alien ops - for (auto lcOp : lcOps) { - ::std::set<::mlir::Operation *> visited; - auto &dom = this->getAnalysis<::mlir::DominanceInfo>(); - ::std::set<::mlir::Operation *, ::imex::opOrderCmp> alien{ - ::imex::opOrderCmp(dom)}; - - auto ewHasCore = hasCore(lcOp.first); - - ::imex::ValVec coreOffs = lcOp.second.getResultOffsets(); - ::imex::ValVec coreSzs = lcOp.second.getResultSizes(); - ::imex::ValVec targetOffs = lcOp.second.getTargetOffsets(); - - propagateAddLocalCore(builder, lcOp.first, lcOp.second, coreOffs, - coreSzs, targetOffs, visited, alien); - - // update local cores if we found "alien" ewops - if (!alien.empty()) { - auto loc = rpOp->getLoc(); - auto currCore = getCore(lcOp.first); - coreOffs = std::get<0>(currCore); - coreSzs = std::get<1>(currCore); - - ::std::set<::mlir::Operation *> lcOpsDone; - - for (auto a : alien) { - auto core = - getCore(a); // should be the same as currCore if ewHasCore - auto aCoreOffs = std::get<0>(core); - auto defOp = - ewHasCore - ? lcOp.second - : aCoreOffs[0].getDefiningOp<::imex::dist::LocalCoreOp>(); - - if (lcOpsDone.find(defOp) == lcOpsDone.end()) { - lcOpsDone.emplace(defOp); - auto aCoreSzs = std::get<1>(core); - auto aTargetOffs = std::get<2>(core); - auto nlcOp = builder.create<::imex::dist::LocalCoreOp>( - loc, defOp.getArray(), defOp.getTargetOffsets(), - defOp.getTargetSizes(), defOp.getSliceOffsets(), - defOp.getSliceSizes(), defOp.getSliceStrides(), coreOffs, - coreSzs); - ::mlir::SmallVector<::mlir::Operation *> deps; - for (auto o : nlcOp->getOperands()) { - if (auto tmp = o.getDefiningOp()) { - deps.emplace_back(tmp); - } - } - std::sort(deps.begin(), deps.end(), ::imex::opOrderCmp(dom)); - nlcOp->moveAfter(deps.back()); - - coreOffs = nlcOp.getResultOffsets(); - coreSzs = nlcOp.getResultSizes(); - - if (ewHasCore) { - // if there was already a core then we had updated all deps - // already - // -> all deps should have the same core -> no need to handle - // other aliens - continue; - } - } - } - - // update full chain of ewops with new core - auto rank = coreOffs.size(); - for (auto vop : visited) { - auto cStart = ::mlir::isa<::imex::dist::EWBinOp>(vop) ? 2 : 1; - vop->setOperands(cStart, rank, coreOffs); - vop->setOperands(cStart + rank, rank, coreSzs); - } - } - } - } - }; // runOnOperation() -}; // DistInferEWCoresPass - -} // namespace -} // namespace dist - -/// Create DistInferEWBinopPass -std::unique_ptr<::mlir::Pass> createDistInferEWCoresPass() { - return std::make_unique<::imex::dist::DistInferEWCoresPass>(); -} - -} // namespace imex diff --git a/lib/Dialect/DistRuntime/IR/CMakeLists.txt b/lib/Dialect/DistRuntime/IR/CMakeLists.txt index ec6c84e7a..476c373a3 100644 --- a/lib/Dialect/DistRuntime/IR/CMakeLists.txt +++ b/lib/Dialect/DistRuntime/IR/CMakeLists.txt @@ -1,6 +1,5 @@ add_imex_dialect_library(IMEXDistRuntimeDialect DistRuntimeOps.cpp - GetHaloOp.cpp CopyReshapeOp.cpp CopyPermuteOp.cpp diff --git a/lib/Dialect/DistRuntime/IR/CopyPermuteOp.cpp b/lib/Dialect/DistRuntime/IR/CopyPermuteOp.cpp index 93e972017..602372a15 100644 --- a/lib/Dialect/DistRuntime/IR/CopyPermuteOp.cpp +++ b/lib/Dialect/DistRuntime/IR/CopyPermuteOp.cpp @@ -42,8 +42,7 @@ class CopyPermuteOpResultCanonicalizer final // check input type auto dstArray = op.getNlArray(); - auto dstType = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(dstArray.getType()); + auto dstType = mlir::dyn_cast<::mlir::RankedTensorType>(dstArray.getType()); if (!dstType) { return ::mlir::failure(); } @@ -63,8 +62,8 @@ class CopyPermuteOpResultCanonicalizer final op.getNlShape(), op.getAxes()); // cast to original types and replace op - auto res = rewriter.create(op.getLoc(), dstType, - newOp.getNlArray()); + auto res = rewriter.create<::mlir::tensor::CastOp>(op.getLoc(), dstType, + newOp.getNlArray()); rewriter.replaceOp(op, {newOp.getHandle(), res}); return ::mlir::success(); @@ -83,11 +82,11 @@ class CopyPermuteCastFolder final matchAndRewrite(::imex::distruntime::CopyPermuteOp op, mlir::PatternRewriter &rewriter) const override { auto src = op.getLArray(); - auto castOp = mlir::dyn_cast(src.getDefiningOp()); + auto castOp = mlir::dyn_cast<::mlir::tensor::CastOp>(src.getDefiningOp()); if (!castOp) return mlir::failure(); - if (!imex::ndarray::canFoldIntoConsumerOp(castOp)) + if (!mlir::tensor::canFoldIntoConsumerOp(castOp)) return mlir::failure(); auto newOp = rewriter.create<::imex::distruntime::CopyPermuteOp>( diff --git a/lib/Dialect/DistRuntime/IR/CopyReshapeOp.cpp b/lib/Dialect/DistRuntime/IR/CopyReshapeOp.cpp index 658c9f2f6..182861cae 100644 --- a/lib/Dialect/DistRuntime/IR/CopyReshapeOp.cpp +++ b/lib/Dialect/DistRuntime/IR/CopyReshapeOp.cpp @@ -42,8 +42,7 @@ class CopyReshapeOpResultCanonicalizer final // check input type auto nlArray = op.getNlArray(); - auto nlType = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(nlArray.getType()); + auto nlType = mlir::dyn_cast<::mlir::RankedTensorType>(nlArray.getType()); if (!nlType) { return ::mlir::failure(); } @@ -76,8 +75,8 @@ class CopyReshapeOpResultCanonicalizer final op.getNlOffsets(), op.getNlShape()); // cast to original types and replace op - auto res = rewriter.create(op.getLoc(), nlType, - newOp.getNlArray()); + auto res = rewriter.create(op.getLoc(), nlType, + newOp.getNlArray()); rewriter.replaceOp(op, {newOp.getHandle(), res}); return ::mlir::success(); diff --git a/lib/Dialect/DistRuntime/IR/GetHaloOp.cpp b/lib/Dialect/DistRuntime/IR/GetHaloOp.cpp deleted file mode 100644 index 7938b7c42..000000000 --- a/lib/Dialect/DistRuntime/IR/GetHaloOp.cpp +++ /dev/null @@ -1,180 +0,0 @@ -//===- GetHaloOp.cpp - distruntime dialect ---------------------*- C++ -*-===// -// -// Copyright 2023 Intel Corporation -// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -/// -/// \file -/// This file implements the GetHaloOp of the DistRuntime dialect. -/// -//===----------------------------------------------------------------------===// - -#include -#include -#include - -namespace imex { -namespace distruntime { - -void GetHaloOp::build(::mlir::OpBuilder &odsBuilder, - ::mlir::OperationState &odsState, ::mlir::Value local, - ::mlir::ValueRange gShape, ::mlir::ValueRange lOffsets, - ::mlir::ValueRange bbOffsets, ::mlir::ValueRange bbSizes, - ::mlir::ValueRange lHSizes, ::mlir::ValueRange rHSizes, - ::mlir::Attribute team, int64_t key) { - auto lShp = getShapeFromValues(lHSizes); - auto rShp = getShapeFromValues(rHSizes); - auto arType = mlir::cast<::imex::ndarray::NDArrayType>(local.getType()); - auto elType = arType.getElementType(); - build(odsBuilder, odsState, - ::imex::distruntime::AsyncHandleType::get(elType.getContext()), - arType.cloneWith(lShp, elType), arType.cloneWith(rShp, elType), local, - gShape, lOffsets, bbOffsets, bbSizes, team, - odsBuilder.getI64IntegerAttr(key)); -} - -::mlir::SmallVector<::mlir::Value> GetHaloOp::getDependent() { - return {getLHalo(), getRHalo()}; -} - -} // namespace distruntime -} // namespace imex - -namespace { - -/// Pattern to replace dynamically shaped result halo types -/// by statically shaped halo result types. -/// It is assumed that for unit-sized ndarrays the halo sizes have static sizes -/// always. This is a slightly complicated canonicalization because it requires -/// computing the static sizes of the halos. -class GetHaloOpResultCanonicalizer final - : public mlir::OpRewritePattern<::imex::distruntime::GetHaloOp> { -public: - using mlir::OpRewritePattern< - ::imex::distruntime::GetHaloOp>::OpRewritePattern; - - mlir::LogicalResult - matchAndRewrite(::imex::distruntime::GetHaloOp op, - ::mlir::PatternRewriter &rewriter) const override { - - // check input type - auto lData = op.getLocal(); - auto lType = mlir::dyn_cast<::imex::ndarray::NDArrayType>(lData.getType()); - auto rank = lType.getRank(); - if (!lType || rank == 0) - return ::mlir::failure(); - - // local data type - auto arType = mlir::dyn_cast<::imex::ndarray::NDArrayType>(lData.getType()); - auto lSizes = arType.getShape(); - - // if dyn type, check if this came from a CastOp - if (::mlir::ShapedType::isDynamicShape(lSizes)) { - if (auto defOp = lData.getDefiningOp<::imex::ndarray::CastOp>()) { - lData = defOp.getSource(); - arType = mlir::dyn_cast<::imex::ndarray::NDArrayType>(lData.getType()); - lSizes = arType.getShape(); - } - } - - // Get current halos types and shapes - auto lHType = - mlir::cast<::imex::ndarray::NDArrayType>(op.getLHalo().getType()); - auto rHType = - mlir::cast<::imex::ndarray::NDArrayType>(op.getRHalo().getType()); - auto lResSzs = lHType.getShape(); - auto rResSzs = rHType.getShape(); - auto lDyn = ::mlir::ShapedType::isDynamicShape(lResSzs); - auto rDyn = ::mlir::ShapedType::isDynamicShape(rResSzs); - - // nothing to do if the result types are already static - if (!(lDyn || rDyn)) { - return ::mlir::failure(); - } - - // Get all dependent values needed to compute halo sizes (bb and loffsets) - bool moded = false; - auto lOffsets = ::imex::getShapeFromValues(op.getLOffsets()); - auto bbOffs = ::imex::getShapeFromValues(op.getBbOffsets()); - auto bbSizes = ::imex::getShapeFromValues(op.getBbSizes()); - - lDyn = ::mlir::ShapedType::isDynamicShape(lResSzs[0]); - rDyn = ::mlir::ShapedType::isDynamicShape(rResSzs[0]); - - ::mlir::SmallVector lHSizes(lResSzs), rHSizes(rResSzs); - // if the first (split) dim is non-constant for any halo -> try to determine - // their size in first dim - if (lDyn || rDyn) { - // all dependent values for computation - auto bbOff = bbOffs[0]; - auto bbSize = bbSizes[0]; - auto oldOff = lOffsets[0]; - auto oldSize = lSizes[0]; - - // only if all are statically known we can compute size in first dim - if (!::mlir::ShapedType::isDynamic(bbOff) && - !::mlir::ShapedType::isDynamic(bbSize) && - !::mlir::ShapedType::isDynamic(oldOff) && - !::mlir::ShapedType::isDynamic(oldSize)) { - auto tEnd = bbOff + bbSize; - auto oldEnd = oldOff + oldSize; - auto ownOff = std::max(oldOff, bbOff); - auto ownSize = std::max(std::min(oldEnd, tEnd) - ownOff, 0L); - - lHSizes[0] = std::min(ownOff, tEnd) - bbOff; - rHSizes[0] = std::max(tEnd - (ownOff + ownSize), 0L); - moded = true; - } - } - - // all other dims: if not statically known already check if bb size is - // statically known - for (auto i = 1; i < rank; ++i) { - if (!::mlir::ShapedType::isDynamic(bbSizes[i])) { - if (::mlir::ShapedType::isDynamic(lHSizes[i])) { - lHSizes[i] = bbSizes[i]; - moded = true; - } - if (::mlir::ShapedType::isDynamic(rHSizes[i])) { - rHSizes[i] = bbSizes[i]; - moded = true; - } - } - } - - // no new static size determined? - if (!moded) { - return ::mlir::failure(); - } - - // make new halo types and create new GetHaloOp - auto elTyp = lType.getElementType(); - auto lTyp = lType.cloneWith(lHSizes, elTyp); - auto rTyp = lType.cloneWith(rHSizes, elTyp); - - auto newOp = rewriter.create<::imex::distruntime::GetHaloOp>( - op.getLoc(), - ::imex::distruntime::AsyncHandleType::get(lTyp.getContext()), lTyp, - rTyp, lData, op.getGShape(), op.getLOffsets(), op.getBbOffsets(), - op.getBbSizes(), op.getTeamAttr(), op.getKeyAttr()); - - // cast to original types and replace op - auto lH = rewriter.create(op.getLoc(), lHType, - newOp.getLHalo()); - auto rH = rewriter.create(op.getLoc(), rHType, - newOp.getRHalo()); - rewriter.replaceOp(op, {newOp.getHandle(), lH, rH}); - - return ::mlir::success(); - } -}; - -} // namespace - -void imex::distruntime::GetHaloOp::getCanonicalizationPatterns( - mlir::RewritePatternSet &results, mlir::MLIRContext *context) { - results.add(context); -} diff --git a/lib/Dialect/DistRuntime/Transforms/AddCommCacheKeys.cpp b/lib/Dialect/DistRuntime/Transforms/AddCommCacheKeys.cpp deleted file mode 100644 index ffa6a5308..000000000 --- a/lib/Dialect/DistRuntime/Transforms/AddCommCacheKeys.cpp +++ /dev/null @@ -1,51 +0,0 @@ -//===- AddCommCacheKeys.cpp - OverlapCommAndCompute Transform *- C++ -*-// -// -// Copyright 2023 Intel Corporation -// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -/// -/// \file -/// This file implements adding unique keys to update_halo ops for caching. -/// -//===----------------------------------------------------------------------===// - -#include -#include - -namespace imex { -#define GEN_PASS_DEF_ADDCOMMCACHEKEYS -#include "imex/Dialect/DistRuntime/Transforms/Passes.h.inc" -} // namespace imex - -namespace imex { -namespace distruntime { - -namespace { - -struct AddCommCacheKeysPass - : public imex::impl::AddCommCacheKeysBase { - - AddCommCacheKeysPass() = default; - - /// @brief Add unique cache key to every distruntime::GetHaloOp - void runOnOperation() override { - auto root = this->getOperation(); - static int64_t key = -1; - - // find all GetHaloOps and assign a unique key to each instance - root->walk([&](::imex::distruntime::GetHaloOp op) { op.setKey(++key); }); - } -}; - -} // namespace -} // namespace distruntime - -/// Create a pass to eliminate Dist ops -std::unique_ptr<::mlir::Pass> createAddCommCacheKeysPass() { - return std::make_unique<::imex::distruntime::AddCommCacheKeysPass>(); -} - -} // namespace imex diff --git a/lib/Dialect/DistRuntime/Transforms/CMakeLists.txt b/lib/Dialect/DistRuntime/Transforms/CMakeLists.txt index 034ed96be..31a265fc9 100644 --- a/lib/Dialect/DistRuntime/Transforms/CMakeLists.txt +++ b/lib/Dialect/DistRuntime/Transforms/CMakeLists.txt @@ -1,7 +1,5 @@ add_imex_dialect_library(IMEXDistRuntimeTransforms DistRuntimeToIDTR.cpp - OverlapCommAndCompute.cpp - AddCommCacheKeys.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/imex/Dialect/DistRuntime @@ -12,5 +10,4 @@ add_imex_dialect_library(IMEXDistRuntimeTransforms LINK_LIBS PUBLIC MLIRIR MLIRPass - IMEXDistTransforms ) diff --git a/lib/Dialect/DistRuntime/Transforms/DistRuntimeToIDTR.cpp b/lib/Dialect/DistRuntime/Transforms/DistRuntimeToIDTR.cpp index d83171708..598be757f 100644 --- a/lib/Dialect/DistRuntime/Transforms/DistRuntimeToIDTR.cpp +++ b/lib/Dialect/DistRuntime/Transforms/DistRuntimeToIDTR.cpp @@ -11,7 +11,6 @@ /// This file implements lowering distruntime ops to calls to IDTR. //===----------------------------------------------------------------------===// -#include #include #include #include @@ -122,46 +121,6 @@ struct RuntimePrototypes { } }; -/// Convert ::imex::distruntime::TeamSizeOp into call to _idtr_nprocs -struct TeamSizeOpPattern - : public ::mlir::OpRewritePattern<::imex::distruntime::TeamSizeOp> { - using ::mlir::OpRewritePattern< - ::imex::distruntime::TeamSizeOp>::OpRewritePattern; - - ::mlir::LogicalResult - matchAndRewrite(::imex::distruntime::TeamSizeOp op, - ::mlir::PatternRewriter &rewriter) const override { - auto team = mlir::cast<::mlir::IntegerAttr>(op.getTeam()).getInt(); - auto loc = op.getLoc(); - - rewriter.replaceOpWithNewOp<::mlir::func::CallOp>( - op, "_idtr_nprocs", rewriter.getIndexType(), - createInt(loc, rewriter, team)); - - return ::mlir::success(); - } -}; - -/// Convert ::imex::distruntime::TeamMemberOp into call to _idtr_prank -struct TeamMemberOpPattern - : public ::mlir::OpRewritePattern<::imex::distruntime::TeamMemberOp> { - using ::mlir::OpRewritePattern< - ::imex::distruntime::TeamMemberOp>::OpRewritePattern; - - ::mlir::LogicalResult - matchAndRewrite(::imex::distruntime::TeamMemberOp op, - ::mlir::PatternRewriter &rewriter) const override { - auto team = mlir::cast<::mlir::IntegerAttr>(op.getTeam()).getInt(); - auto loc = op.getLoc(); - - rewriter.replaceOpWithNewOp<::mlir::func::CallOp>( - op, "_idtr_prank", rewriter.getIndexType(), - createInt(loc, rewriter, team)); - - return ::mlir::success(); - } -}; - struct CopyReshapeOpPattern : public ::mlir::OpRewritePattern<::imex::distruntime::CopyReshapeOp> { using ::mlir::OpRewritePattern< @@ -171,10 +130,9 @@ struct CopyReshapeOpPattern matchAndRewrite(::imex::distruntime::CopyReshapeOp op, ::mlir::PatternRewriter &rewriter) const override { auto lArray = op.getLArray(); - auto arType = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(lArray.getType()); + auto arType = mlir::dyn_cast<::mlir::RankedTensorType>(lArray.getType()); auto resType = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(op.getNlArray().getType()); + mlir::dyn_cast<::mlir::RankedTensorType>(op.getNlArray().getType()); if (!arType || !resType) { return ::mlir::failure(); } @@ -191,7 +149,7 @@ struct CopyReshapeOpPattern // create output array with target size auto nlArray = rewriter.create<::imex::ndarray::CreateOp>( loc, nlShape, ::imex::ndarray::fromMLIR(elType), nullptr, - resType.getEnvironments()); + resType.getEncoding()); auto idxType = rewriter.getIndexType(); auto teamC = rewriter.create<::mlir::arith::ConstantOp>( @@ -215,154 +173,6 @@ struct CopyReshapeOpPattern } }; -/// @brief lower GetHaloOp -/// Determine sizes of halos, alloc halos and call idtr. -/// Before accessing/reading from returned halos, the caller must -/// call the appropriate wait call in idtr. -/// @return handle, left halo, right halo -struct GetHaloOpPattern - : public ::mlir::OpRewritePattern<::imex::distruntime::GetHaloOp> { - using ::mlir::OpRewritePattern< - ::imex::distruntime::GetHaloOp>::OpRewritePattern; - - ::mlir::LogicalResult - matchAndRewrite(::imex::distruntime::GetHaloOp op, - ::mlir::PatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - auto lData = op.getLocal(); - auto arTyp = mlir::dyn_cast<::imex::ndarray::NDArrayType>(lData.getType()); - if (!arTyp) - return ::mlir::failure(); - - auto elType = arTyp.getElementType(); - - auto mkHalo = [&](const ::imex::ValVec &szs) { - ::mlir::Value iVal = -#ifdef DEBUG_HALO - createCast(loc, rewriter, - elType.isIntOrIndex() - ? createInt(loc, rewriter, 4711, - elType.getIntOrFloatBitWidth()) - : createFloat(loc, rewriter, 4711, - elType.getIntOrFloatBitWidth()), - elType); -#else - nullptr; -#endif - auto outPTnsr = rewriter.create<::imex::ndarray::CreateOp>( - loc, szs, ::imex::ndarray::fromMLIR(elType), iVal, - ::imex::dist::getNonDistEnvs(arTyp)); - auto outUMR = ::imex::ndarray::mkURMemRef(loc, rewriter, outPTnsr); - return std::make_pair(outPTnsr, outUMR); - }; - - ::imex::ValVec lSizes = - ::imex::ndarray::createShapeOf(loc, rewriter, lData); - auto lOffsets = op.getLOffsets(); - auto gShape = op.getGShape(); - ::imex::ValVec bbOffs = op.getBbOffsets(); - ::imex::ValVec bbSizes = op.getBbSizes(); - - // Prepare args for calling update_halo - auto idxType = rewriter.getIndexType(); - auto gShapeMR = createURMemRefFromElements(rewriter, loc, idxType, gShape); - auto lOffsMR = createURMemRefFromElements(rewriter, loc, idxType, lOffsets); - // we pass the entire local data to update_halo, not just the subview - auto lPart = ::imex::ndarray::mkURMemRef(loc, rewriter, lData); - auto bbOffsMR = createURMemRefFromElements(rewriter, loc, idxType, bbOffs); - auto bbSizesMR = - createURMemRefFromElements(rewriter, loc, idxType, bbSizes); - - // determine overlap of new local part, we split dim 0 only - auto zero = easyIdx(loc, rewriter, 0); - auto one = easyIdx(loc, rewriter, 1); - auto bbOff = easyIdx(loc, rewriter, bbOffs[0]); - auto bbSize = easyIdx(loc, rewriter, bbSizes[0]); - auto oldOff = easyIdx(loc, rewriter, lOffsets[0]); - auto oldSize = easyIdx(loc, rewriter, lSizes[0]); - auto tEnd = bbOff + bbSize; - auto oldEnd = oldOff + oldSize; - auto ownOff = oldOff.max(bbOff); - auto ownSize = (oldEnd.min(tEnd) - ownOff).max(zero); - - // compute left and right halo sizes, we split dim 0 only - ::imex::ValVec lHSizes(bbSizes), rHSizes(bbSizes); - auto sgShape = getShapeFromValues(gShape); - if (::imex::ndarray::isUnitShape(sgShape)) { - lHSizes[0] = - oldSize.eq(zero).land(oldOff.sgt(zero)).select(one, zero).get(); - rHSizes[0] = - oldSize.eq(zero).land(oldOff.sle(zero)).select(one, zero).get(); - } else { - lHSizes[0] = (ownOff.min(tEnd) - bbOff).get(); - rHSizes[0] = (tEnd - (ownOff + ownSize)).max(zero).get(); - } - - auto lOut = mkHalo(lHSizes); - auto rOut = mkHalo(rHSizes); - auto key = createInt(loc, rewriter, op.getKey()); - - // call our runtime function to redistribute data across processes - auto fun = rewriter.getStringAttr(mkTypedFunc("_idtr_update_halo", elType)); - auto handle = rewriter.create<::mlir::func::CallOp>( - loc, fun, rewriter.getI64Type(), - ::mlir::ValueRange{createInt(loc, rewriter, 0), gShapeMR, lOffsMR, - lPart, bbOffsMR, bbSizesMR, lOut.second, rOut.second, - key}); - - rewriter.replaceOp(op, {handle.getResult(0), lOut.first, rOut.first}); - return ::mlir::success(); - } -}; - -/// Convert ::imex::distruntime::AllReduceOp into runtime call to -/// "_idtr_reduce_all". Pass local RankedTensor as argument. Replaces op with -/// new distributed array. -struct AllReduceOpPattern - : public ::mlir::OpRewritePattern<::imex::distruntime::AllReduceOp> { - using ::mlir::OpRewritePattern< - ::imex::distruntime::AllReduceOp>::OpRewritePattern; - - ::mlir::LogicalResult - matchAndRewrite(::imex::distruntime::AllReduceOp op, - ::mlir::PatternRewriter &rewriter) const override { - // get guid and rank and call runtime function - auto loc = op.getLoc(); - auto mRef = op.getData(); - auto mRefType = mlir::dyn_cast<::mlir::MemRefType>(mRef.getType()); - if (!mRefType) - return ::mlir::failure(); - - auto opV = rewriter.create<::mlir::arith::ConstantOp>( - loc, ::mlir::cast<::mlir::TypedAttr>(op.getOp())); - auto elType = mRefType.getElementType(); - - auto fsa = rewriter.getStringAttr(mkTypedFunc("_idtr_reduce_all", elType)); - auto dataUMR = createUnrankedMemRefCast(rewriter, loc, mRef); - - rewriter.replaceOpWithNewOp<::mlir::func::CallOp>( - op, fsa, ::mlir::TypeRange(), ::mlir::ValueRange({dataUMR, opV})); - return ::mlir::success(); - } -}; - -/// Convert ::imex::distruntime::WaitOp into call to _idtr_wait -struct WaitOpPattern - : public ::mlir::OpRewritePattern<::imex::distruntime::WaitOp> { - using ::mlir::OpRewritePattern<::imex::distruntime::WaitOp>::OpRewritePattern; - - ::mlir::LogicalResult - matchAndRewrite(::imex::distruntime::WaitOp op, - ::mlir::PatternRewriter &rewriter) const override { - auto fsa = rewriter.getStringAttr("_idtr_wait"); - rewriter.replaceOpWithNewOp<::mlir::func::CallOp>( - op, fsa, ::mlir::TypeRange(), ::mlir::ValueRange{op.getHandle()}); - - return ::mlir::success(); - } -}; - struct CopyPermuteOpPattern : public ::mlir::OpRewritePattern<::imex::distruntime::CopyPermuteOp> { using ::mlir::OpRewritePattern< @@ -372,10 +182,9 @@ struct CopyPermuteOpPattern matchAndRewrite(::imex::distruntime::CopyPermuteOp op, ::mlir::PatternRewriter &rewriter) const override { auto lArray = op.getLArray(); - auto arType = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(lArray.getType()); + auto arType = mlir::dyn_cast<::mlir::RankedTensorType>(lArray.getType()); auto resType = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(op.getNlArray().getType()); + mlir::dyn_cast<::mlir::RankedTensorType>(op.getNlArray().getType()); if (!arType || !resType) { return ::mlir::failure(); } @@ -399,7 +208,7 @@ struct CopyPermuteOpPattern // create output array with target size auto nlArray = rewriter.create<::imex::ndarray::CreateOp>( loc, nlShape, ::imex::ndarray::fromMLIR(elType), nullptr, - resType.getEnvironments()); + resType.getEncoding()); auto idxType = rewriter.getIndexType(); auto teamC = rewriter.create<::mlir::arith::ConstantOp>( @@ -434,9 +243,8 @@ struct DistRuntimeToIDTRPass RuntimePrototypes::add_prototypes(builder, this->getOperation()); ::mlir::FrozenRewritePatternSet patterns; - insertPatterns(getContext(), patterns); + insertPatterns(getContext(), + patterns); (void)::mlir::applyPatternsAndFoldGreedily(this->getOperation(), patterns); }; // runOnOperation() @@ -455,55 +263,3 @@ extern "C" { int _idtr_nprocs(void *) __attribute__((weak)); int _idtr_prank(void *) __attribute__((weak)); } - -namespace imex { -namespace distruntime { - -static auto DNDA_NPROCS = getenv("DNDA_NPROCS"); -static auto DNDA_PRANK = getenv("DNDA_PRANK"); - -::mlir::OpFoldResult TeamSizeOp::fold(FoldAdaptor adaptor) { - // call runtime at compile time if available and team is constant - if (DNDA_NPROCS) { - auto np = std::stoi(DNDA_NPROCS); - ::mlir::Builder builder(getContext()); - return builder.getIndexAttr(np); - } - if (_idtr_nprocs != NULL) { - ::mlir::Builder builder(getContext()); - auto team = mlir::cast<::mlir::IntegerAttr>(adaptor.getTeam()).getInt(); - auto np = _idtr_nprocs(reinterpret_cast(team)); - return builder.getIndexAttr(np); - } - return nullptr; -} - -::mlir::OpFoldResult TeamMemberOp::fold(FoldAdaptor adaptor) { - // call runtime at compile time if available and team is constant - if (DNDA_PRANK) { - auto np = std::stoi(DNDA_PRANK); - ::mlir::Builder builder(getContext()); - return builder.getIndexAttr(np); - } - if (_idtr_prank != NULL) { - ::mlir::Builder builder(getContext()); - auto team = mlir::cast<::mlir::IntegerAttr>(adaptor.getTeam()).getInt(); - auto np = _idtr_prank(reinterpret_cast(team)); - return builder.getIndexAttr(np); - } - return nullptr; -} - -/// Materialize a single constant operation from a given attribute value with -/// the desired resultant type. -/// Ported from mlir::tensor dialect -mlir::Operation *imex::distruntime::DistRuntimeDialect::materializeConstant( - mlir::OpBuilder &builder, mlir::Attribute value, mlir::Type type, - mlir::Location loc) { - if (auto op = mlir::arith::ConstantOp::materialize(builder, value, type, loc)) - return op; - return nullptr; -} - -} // namespace distruntime -} // namespace imex diff --git a/lib/Dialect/DistRuntime/Transforms/OverlapCommAndCompute.cpp b/lib/Dialect/DistRuntime/Transforms/OverlapCommAndCompute.cpp deleted file mode 100644 index 34867cfab..000000000 --- a/lib/Dialect/DistRuntime/Transforms/OverlapCommAndCompute.cpp +++ /dev/null @@ -1,210 +0,0 @@ -//===- OverlapCommAndCompute.cpp - OverlapCommAndCompute Transform *- C++ -*-// -// -// Copyright 2023 Intel Corporation -// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -/// -/// \file This file implements overlapping communication and computation. -/// The pass tries to pull up ewops which are users of other ewops so that -/// similar, e.g. dependent, ewops are close together. -/// The main purpose of this is to allow separate groups of ewops from each -/// other so that asynchronous operations can effectively operate in the -/// background. For this, the pass pushes down WaitOps to the first use of the -/// data they protect. It is also necessary to push down SubViewOps to their -/// first use. -/// -//===----------------------------------------------------------------------===// - -#include -#include -#include -#include -#include -#include - -#include -#include - -namespace imex { -#define GEN_PASS_DEF_OVERLAPCOMMANDCOMPUTE -#include -} // namespace imex - -namespace imex { -namespace ndarray { -std::vector<::mlir::Operation *> getSortedUsers(::mlir::DominanceInfo &dom, - ::mlir::Operation *op) { - std::vector<::mlir::Operation *> users(op->getUsers().begin(), - op->getUsers().end()); - std::sort(users.begin(), users.end(), ::imex::opOrderCmp(dom)); - return users; -} - -namespace { - -struct OverlapCommAndComputePass - : public impl::OverlapCommAndComputeBase { - - OverlapCommAndComputePass() = default; - - // pull up dependent ewops if possible - void pullDescendents(::std::vector<::mlir::Operation *> &grp) { - auto &dom = this->getAnalysis<::mlir::DominanceInfo>(); - - for (auto op : grp) { - ::std::vector<::mlir::Operation *> allOps; - ::mlir::Operation *curr = op; - - if (std::find(allOps.begin(), allOps.end(), op) == allOps.end()) { - allOps.emplace_back(op); - } - - // pull all users - for (auto user = op->getUsers().begin(); user != op->getUsers().end(); - ++user) { - if (::mlir::isa<::imex::ndarray::EWBinOp, ::imex::ndarray::EWUnyOp>( - *user) && - std::find(grp.begin(), grp.end(), *user) != grp.end() && - std::find(allOps.begin(), allOps.end(), *user) == allOps.end()) { - ::mlir::SmallVector<::mlir::Operation *> toBeMoved; - // store users and dependences for later sorting and move - if (canMoveAfter(dom, *user, curr, toBeMoved)) { - for (auto dop : toBeMoved) { - if (::mlir::isa<::imex::distruntime::GetHaloOp>(dop)) { - toBeMoved.clear(); - break; // FIXME can we do this less restrictive? - } - } - for (auto dop : toBeMoved) { - if (std::find(allOps.begin(), allOps.end(), dop) == - allOps.end()) { - allOps.emplace_back(dop); - } - } - curr = allOps.back(); - } - } - } - - // sort all deps and move to right after ewop - if (!allOps.empty()) { - std::sort(allOps.begin(), allOps.end(), ::imex::opOrderCmp(dom)); - curr = allOps.front(); - for (auto dop : allOps) { - if (dop != curr) { - assert(!::mlir::isa<::imex::distruntime::GetHaloOp>(dop)); - dop->moveAfter(curr); - curr = dop; - } - } - } - } - } - - // push operation down to first use - void pushDefiningOp(::mlir::Operation *op) { - auto &dom = this->getAnalysis<::mlir::DominanceInfo>(); - auto users = getSortedUsers(dom, op); - if (!users.empty()) { - op->moveBefore(users.front()); - } - } - - // collect all users of given value, excluding wait ops - static void appendUsers(::mlir::Value val, - ::mlir::SmallVector<::mlir::Operation *> &users) { - for (auto it = val.user_begin(); it != val.user_end(); ++it) { - auto op = *it; - if (!::mlir::isa<::imex::distruntime::WaitOp>(op)) { - // A cast op can be ignored as it is basically a no-op - // but the result needs to be tracked - if (auto castOp = ::mlir::dyn_cast<::imex::ndarray::CastOp>(op)) { - appendUsers(castOp.getDestination(), users); - } else { - users.push_back(op); - } - } - } - }; - - // From the WaitOps defining GetHaloOp get resulting halos. - // Move WaitOp to first use of any of the halos. - void pushWaitOp(::mlir::Operation *op) { - auto waitOp = ::mlir::cast<::imex::distruntime::WaitOp>(op); - auto asyncOp = waitOp.getHandle().getDefiningOp<::mlir::AsyncOpInterface>(); - assert(asyncOp); - - ::mlir::SmallVector<::mlir::Operation *> users; - for (auto d : asyncOp.getDependent()) { - appendUsers(d, users); - } - - // sort - auto &dom = this->getAnalysis<::mlir::DominanceInfo>(); - std::sort(users.begin(), users.end(), ::imex::opOrderCmp(dom)); - - // push WaitOp down - if (!users.empty()) { - op->moveBefore(users.front()); - } - // if there is no user, we still need the wait call. - } - - /// @brief group ewops, push out SubviewOps and WaitOps as much as possible. - /// Do not pull ewops over InsertSliceOps - void runOnOperation() override { - auto root = this->getOperation(); - ::std::vector<::mlir::Operation *> ewops, svops, waitops; - ::std::vector<::std::vector<::mlir::Operation *>> ewgroups; - - // find all ewops, WaitOps and Subviewops - // create groups of ewops separated by InsertSliceOps - root->walk([&](::mlir::Operation *op) { - if (::mlir::isa<::imex::ndarray::EWBinOp, ::imex::ndarray::EWUnyOp>(op)) { - ewops.emplace_back(op); - } else if (::mlir::isa<::imex::ndarray::SubviewOp>(op)) { - svops.emplace_back(op); - } else if (::mlir::isa<::imex::distruntime::WaitOp>(op)) { - waitops.emplace_back(op); - } else if (::mlir::isa<::imex::ndarray::InsertSliceOp>(op) && - !ewops.empty()) { - ewgroups.push_back(std::move(ewops)); - assert(ewops.empty()); - } - }); - if (!ewops.empty()) { - ewgroups.emplace_back(std::move(ewops)); - } - - // within each group, pull up dependent ewops - for (auto grp : ewgroups) { - pullDescendents(grp); - grp.clear(); - } - ewgroups.clear(); - - // push down SubviewOPs to their first use - for (auto op : svops) { - pushDefiningOp(op); - } - svops.clear(); - - // push down WaitOps to the first use of the data they protect - for (auto op : waitops) { - pushWaitOp(op); - } - waitops.clear(); - } -}; -} // namespace -} // namespace ndarray - -/// Create a pass to eliminate Dist ops -std::unique_ptr<::mlir::Pass> createOverlapCommAndComputePass() { - return std::make_unique<::imex::ndarray::OverlapCommAndComputePass>(); -} - -} // namespace imex diff --git a/lib/Dialect/NDArray/Extensions/MeshShardingExtensions.cpp b/lib/Dialect/NDArray/Extensions/MeshShardingExtensions.cpp index d023fb056..329c5a468 100644 --- a/lib/Dialect/NDArray/Extensions/MeshShardingExtensions.cpp +++ b/lib/Dialect/NDArray/Extensions/MeshShardingExtensions.cpp @@ -6,16 +6,14 @@ // //===----------------------------------------------------------------------===// +#include "imex/Dialect/NDArray/IR/NDArrayOps.h" #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" #include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h" -#include "imex/Dialect/NDArray/IR/NDArrayOps.h" -#include "imex/Dialect/NDArray/IR/NDArrayOps.h" -#include "imex/Dialect/Dist/IR/DistOps.h" #include "mlir/IR/DialectRegistry.h" #include "llvm/Support/Debug.h" -#include -#include #include +#include +#include #define DEBUG_TYPE "ndarray-sharding-impl" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") @@ -23,8 +21,8 @@ using namespace mlir; using namespace mlir::mesh; -using imex::easyIdx; using imex::easyI64; +using imex::easyIdx; namespace imex { namespace ndarray { @@ -57,22 +55,29 @@ static SmallVector getMyMultiIndex(OpBuilder &b, return b.create(mesh.getLoc(), mesh).getResult(); } +template +static auto getBaseShardDimOff(T shard, T numShards, T extend, T zero) { + return (shard * (extend / numShards)) + + (shard - (numShards - (extend % numShards))).max(zero); +}; + // Sharding of tensor.empty -template +template struct OffsetSizeAndStrideShardingInterface - : public ShardingInterface::ExternalModel { - - SmallVector getLoopIteratorTypes(::mlir::Operation *op) const { + : public ShardingInterface::ExternalModel { + + SmallVector + getLoopIteratorTypes(::mlir::Operation *op) const { LLVM_DEBUG(DBGS() << "getLoopIteratorTypes\n"); Value val = op->getOperand(0); auto type = dyn_cast(val.getType()); if (!type) return {}; SmallVector types(type.getRank(), - utils::IteratorType::parallel); + utils::IteratorType::parallel); return types; } - + SmallVector getIndexingMaps(::mlir::Operation *op) const { LLVM_DEBUG(DBGS() << "getIndexingMaps\n"); MLIRContext *ctx = op->getContext(); @@ -224,9 +229,8 @@ struct OffsetSizeAndStrideShardingInterface auto targetOff = easyIdx(loc, builder, slcOffs[dim]) + easyIdx(loc, builder, myOffAndSize.getResult(0)) * easyIdx(loc, builder, slcStrides[dim]); - auto shardOff = - imex::dist::getBaseShardDimOff(myID, numShards, extend, zero) - - easyIdx(loc, builder, haloSizes[shardedDim * 2]); + auto shardOff = getBaseShardDimOff(myID, numShards, extend, zero) - + easyIdx(loc, builder, haloSizes[shardedDim * 2]); return {(targetOff - shardOff).get(), myOffAndSize.getResult(1)}; } @@ -289,7 +293,9 @@ struct OffsetSizeAndStrideShardingInterface } }; -struct SubviewShardingInterface : public OffsetSizeAndStrideShardingInterface { +struct SubviewShardingInterface + : public OffsetSizeAndStrideShardingInterface { LogicalResult addShardingAnnotations(::mlir::Operation *op, OpBuilder &b, const ShardingOption &shardingOption) const { @@ -299,16 +305,23 @@ struct SubviewShardingInterface : public OffsetSizeAndStrideShardingInterfacegetOpOperand(0), b); + maybeInsertSourceShardingAnnotation(srcShardOp.getSharding(), + op->getOpOperand(0), b); auto sharding = getShardedDimsOffsetsSharding(svop.getSource(), svop); - if(failed(sharding)) return failure(); + if (failed(sharding)) + return failure(); maybeInsertTargetShardingAnnotation(sharding.value(), op->getResult(0), b); return success(); } - LogicalResult spmdize(::mlir::Operation *op, ArrayRef spmdizedOperands, ArrayRef operandShardings, ArrayRef resultShardings, IRMapping&spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) const { + LogicalResult spmdize(::mlir::Operation *op, ArrayRef spmdizedOperands, + ArrayRef operandShardings, + ArrayRef resultShardings, + IRMapping &spmdizationMap, + SymbolTableCollection &symbolTableCollection, + OpBuilder &builder) const { if (resultShardings.size() != 1) { return failure(); } @@ -333,8 +346,12 @@ struct SubviewShardingInterface : public OffsetSizeAndStrideShardingInterface { - LogicalResult addShardingAnnotations(::mlir::Operation *op, OpBuilder &b, const ShardingOption &shardingOption) const { +struct InsertSliceShardingInterface + : public OffsetSizeAndStrideShardingInterface< + InsertSliceShardingInterface, imex::ndarray::InsertSliceOp> { + LogicalResult + addShardingAnnotations(::mlir::Operation *op, OpBuilder &b, + const ShardingOption &shardingOption) const { LLVM_DEBUG(DBGS() << "addShardingAnnotations\n"); auto svop = cast(op); @@ -355,7 +372,12 @@ struct InsertSliceShardingInterface : public OffsetSizeAndStrideShardingInterfac return success(); } - LogicalResult spmdize(::mlir::Operation *op, ArrayRef spmdizedOperands, ArrayRef operandShardings, ArrayRef resultShardings, IRMapping&spmdizationMap, SymbolTableCollection &symbolTableCollection, OpBuilder &builder) const { + LogicalResult spmdize(::mlir::Operation *op, ArrayRef spmdizedOperands, + ArrayRef operandShardings, + ArrayRef resultShardings, + IRMapping &spmdizationMap, + SymbolTableCollection &symbolTableCollection, + OpBuilder &builder) const { if (resultShardings.size() != 0) { return failure(); } @@ -394,7 +416,8 @@ struct InsertSliceShardingInterface : public OffsetSizeAndStrideShardingInterfac } // namespace void registerShardingInterfaceExternalModels(mlir::DialectRegistry ®istry) { - registry.addExtension(+[](MLIRContext *ctx, imex::ndarray::NDArrayDialect *dialect) { + registry.addExtension(+[](MLIRContext *ctx, + imex::ndarray::NDArrayDialect *dialect) { SubviewOp::template attachInterface(*ctx); InsertSliceOp::template attachInterface(*ctx); }); diff --git a/lib/Dialect/NDArray/IR/CMakeLists.txt b/lib/Dialect/NDArray/IR/CMakeLists.txt index 270f75893..947097f8c 100644 --- a/lib/Dialect/NDArray/IR/CMakeLists.txt +++ b/lib/Dialect/NDArray/IR/CMakeLists.txt @@ -4,12 +4,8 @@ add_imex_dialect_library(IMEXNDArrayDialect InsertSliceOp.cpp CreateOp.cpp LinSpaceOp.cpp - CastOp.cpp - DimOp.cpp - EWBinOp.cpp - EWUnyOp.cpp - PermuteDimsOp.cpp DeleteOp.cpp + CastElemTypeOp.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/mlir/Dialect/NDArray diff --git a/lib/Dialect/NDArray/IR/CastElemTypeOp.cpp b/lib/Dialect/NDArray/IR/CastElemTypeOp.cpp new file mode 100644 index 000000000..f6a33267b --- /dev/null +++ b/lib/Dialect/NDArray/IR/CastElemTypeOp.cpp @@ -0,0 +1,86 @@ +//===- CastOp.cpp - NDArray dialect --------------------------*- C++ -*-===// +// +// Copyright 2023 Intel Corporation +// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file implements the CastElemTypeOp of the NDArray dialect. +/// +//===----------------------------------------------------------------------===// + +#include + +/// Pattern to rewrite a CastElemTypeOp replacing dynamically shaped inputs +/// by statically shaped inputs if they are defined by an appropriate CastOp. +class CastElemTypeOpInputCanonicalizer final + : public mlir::OpRewritePattern<::imex::ndarray::CastElemTypeOp> { +public: + using mlir::OpRewritePattern< + ::imex::ndarray::CastElemTypeOp>::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(::imex::ndarray::CastElemTypeOp op, + ::mlir::PatternRewriter &rewriter) const override { + + if (!llvm::isa<::mlir::RankedTensorType>(op.getResult().getType())) { + return mlir::failure(); + }; + + auto src = op.getInput(); + auto srcNDTyp = mlir::dyn_cast<::mlir::RankedTensorType>(src.getType()); + auto defOp = src.getDefiningOp<::mlir::tensor::CastOp>(); + if (!srcNDTyp || srcNDTyp.hasStaticShape() || !defOp) { + return mlir::failure(); + } + auto defOpSrc = defOp.getSource(); + auto defSrcNDTyp = + mlir::dyn_cast<::mlir::RankedTensorType>(defOpSrc.getType()); + if (!defSrcNDTyp || !defSrcNDTyp.hasStaticShape()) { + return mlir::failure(); + } + rewriter.modifyOpInPlace(op, [&]() { op->setOperand(0, defOpSrc); }); + return ::mlir::success(); + } +}; + +/// Pattern to rewrite a CastElemTypeOp replacing dynamically shaped result type +/// by statically shaped result type if input is statically shaped. +class CastElemTypeOpResultCanonicalizer final + : public mlir::OpRewritePattern<::imex::ndarray::CastElemTypeOp> { +public: + using mlir::OpRewritePattern< + ::imex::ndarray::CastElemTypeOp>::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(::imex::ndarray::CastElemTypeOp op, + ::mlir::PatternRewriter &rewriter) const override { + + auto src = op.getInput(); + auto srcNDTyp = mlir::dyn_cast<::mlir::RankedTensorType>(src.getType()); + auto resNDTyp = + mlir::dyn_cast<::mlir::RankedTensorType>(op.getResult().getType()); + if (!(srcNDTyp && resNDTyp && srcNDTyp.hasStaticShape() && + !resNDTyp.hasStaticShape())) { + return mlir::failure(); + } + + auto resShape = srcNDTyp.getShape(); + auto resTyp = resNDTyp.cloneWith(resShape, resNDTyp.getElementType()); + auto newOp = rewriter.create<::imex::ndarray::CastElemTypeOp>(op->getLoc(), + resTyp, src); + rewriter.replaceOpWithNewOp<::mlir::tensor::CastOp>(op, resNDTyp, newOp); + + return ::mlir::success(); + } +}; + +void imex::ndarray::CastElemTypeOp::getCanonicalizationPatterns( + mlir::RewritePatternSet &results, mlir::MLIRContext *context) { + results + .add( + context); +} \ No newline at end of file diff --git a/lib/Dialect/NDArray/IR/CastOp.cpp b/lib/Dialect/NDArray/IR/CastOp.cpp deleted file mode 100644 index 52bab7bbf..000000000 --- a/lib/Dialect/NDArray/IR/CastOp.cpp +++ /dev/null @@ -1,229 +0,0 @@ -//===- CastOp.cpp - NDArray dialect --------------------------*- C++ -*-===// -// -// Copyright 2023 Intel Corporation -// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -/// -/// \file -/// This file implements the CastOp of the NDArray dialect. -/// -//===----------------------------------------------------------------------===// - -#include - -/// Ported from mlir::tensor::CastOp -bool imex::ndarray::CastOp::areCastCompatible(mlir::TypeRange inputs, - mlir::TypeRange outputs) { - if (inputs.size() != 1 || outputs.size() != 1) - return false; - mlir::Type a = inputs.front(), b = outputs.front(); - auto aT = llvm::dyn_cast(a); - auto bT = llvm::dyn_cast(b); - if (!aT || !bT) - return false; - - if (aT.getElementType() != bT.getElementType()) - return false; - - return mlir::succeeded( - mlir::verifyCompatibleShape(aT.getTensorType(), bT.getTensorType())); -} - -/// Ported from mlir::tensor -bool imex::ndarray::canFoldIntoConsumerOp(imex::ndarray::CastOp castOp) { - if (!castOp) - return false; - - // Can fold if the source of cast has at least as much static information as - // its results. - return mlir::tensor::preservesStaticInformation( - castOp.getType().getTensorType(), - castOp.getSource().getType().getTensorType()); -} -bool imex::ndarray::canFoldIntoConsumerOp(mlir::tensor::CastOp castOp) { - if (!castOp) - return false; - - // Can fold if the source of cast has at least as much static information as - // its results. - return mlir::tensor::preservesStaticInformation( - castOp.getType(), - castOp.getSource().getType()); -} - -/// Ported from mlir::tensor -mlir::LogicalResult imex::ndarray::foldArrayCast(mlir::Operation *op) { - bool folded = false; - for (mlir::OpOperand &operand : op->getOpOperands()) { - auto castOp = operand.get().getDefiningOp(); - if (castOp && imex::ndarray::canFoldIntoConsumerOp(castOp)) { - operand.set(castOp.getOperand()); - folded = true; - } - } - return mlir::success(folded); -} - -/// Compute a TensorType that has the joined shape knowledge of the two -/// given TensorTypes. The element types need to match. -/// Ported from mlir::tensor -static mlir::TensorType joinShapes(mlir::TensorType one, mlir::TensorType two) { - assert(one.getElementType() == two.getElementType()); - - if (!one.hasRank()) - return two; - if (!two.hasRank()) - return one; - - int64_t rank = one.getRank(); - if (rank != two.getRank()) - return {}; - - mlir::SmallVector join; - join.reserve(rank); - for (int64_t i = 0; i < rank; ++i) { - if (one.isDynamicDim(i)) { - join.push_back(two.getDimSize(i)); - continue; - } - if (two.isDynamicDim(i)) { - join.push_back(one.getDimSize(i)); - continue; - } - if (one.getDimSize(i) != two.getDimSize(i)) - return {}; - join.push_back(one.getDimSize(i)); - } - return mlir::RankedTensorType::get(join, one.getElementType()); -} - -/// Replaces chains of two ndarray.cast operations by a single ndarray.cast -/// operation if doing so does not remove runtime constraints. -/// Ported from mlir::tensor::CastOp -struct ChainedNDArrayCast - : public mlir::OpRewritePattern { - using mlir::OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult - matchAndRewrite(imex::ndarray::CastOp op, - mlir::PatternRewriter &rewriter) const final { - auto arCastOperand = op.getOperand().getDefiningOp(); - - if (!arCastOperand) - return mlir::failure(); - - // infer cast shapes with array types - auto sourceType = llvm::cast( - arCastOperand.getOperand().getType()) - .getTensorType(); - auto intermediateType = - llvm::cast(arCastOperand.getType()) - .getTensorType(); - auto resultType = - llvm::cast(op.getType()).getTensorType(); - - // We can remove the intermediate cast if joining all three produces the - // same result as just joining the source and result shapes.in - auto tmpJoin = joinShapes(sourceType, intermediateType); - if (!tmpJoin) - return mlir::failure(); - auto firstJoin = joinShapes(tmpJoin, resultType); - if (!firstJoin) - return mlir::failure(); - - // The newJoin always exists if the above join exists, it might just contain - // less information. If so, we cannot drop the intermediate cast, as doing - // so would remove runtime checks. - auto newJoin = joinShapes(sourceType, resultType); - if (firstJoin != newJoin) - return mlir::failure(); - - auto sourcePTType = - mlir::dyn_cast(op.getSource().getType()); - auto resultPTType = sourcePTType.cloneWith(resultType.getShape(), - resultType.getElementType()); - ; - rewriter.replaceOpWithNewOp( - op, resultPTType, arCastOperand.getOperand()); - return mlir::success(); - } -}; - -void imex::ndarray::CastOp::getCanonicalizationPatterns( - mlir::RewritePatternSet &results, mlir::MLIRContext *context) { - results.add(context); -} - -/// Pattern to rewrite a CastElemTypeOp replacing dynamically shaped inputs -/// by statically shaped inputs if they are defined by an appropriate CastOp. -class CastElemTypeOpInputCanonicalizer final - : public mlir::OpRewritePattern<::imex::ndarray::CastElemTypeOp> { -public: - using mlir::OpRewritePattern< - ::imex::ndarray::CastElemTypeOp>::OpRewritePattern; - - mlir::LogicalResult - matchAndRewrite(::imex::ndarray::CastElemTypeOp op, - ::mlir::PatternRewriter &rewriter) const override { - - if (!llvm::isa<::imex::ndarray::NDArrayType>(op.getResult().getType())) { - return mlir::failure(); - }; - - auto src = op.getInput(); - auto srcNDTyp = mlir::dyn_cast<::imex::ndarray::NDArrayType>(src.getType()); - auto defOp = src.getDefiningOp<::imex::ndarray::CastOp>(); - if (!srcNDTyp || srcNDTyp.hasStaticShape() || !defOp) { - return mlir::failure(); - } - auto defOpSrc = defOp.getSource(); - auto defSrcNDTyp = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(defOpSrc.getType()); - if (!defSrcNDTyp || !defSrcNDTyp.hasStaticShape()) { - return mlir::failure(); - } - rewriter.modifyOpInPlace(op, [&]() { op->setOperand(0, defOpSrc); }); - return ::mlir::success(); - } -}; - -/// Pattern to rewrite a CastElemTypeOp replacing dynamically shaped result type -/// by statically shaped result type if input is statically shaped. -class CastElemTypeOpResultCanonicalizer final - : public mlir::OpRewritePattern<::imex::ndarray::CastElemTypeOp> { -public: - using mlir::OpRewritePattern< - ::imex::ndarray::CastElemTypeOp>::OpRewritePattern; - - mlir::LogicalResult - matchAndRewrite(::imex::ndarray::CastElemTypeOp op, - ::mlir::PatternRewriter &rewriter) const override { - - auto src = op.getInput(); - auto srcNDTyp = mlir::dyn_cast<::imex::ndarray::NDArrayType>(src.getType()); - auto resNDTyp = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(op.getResult().getType()); - if (!(srcNDTyp && resNDTyp && srcNDTyp.hasStaticShape() && - !resNDTyp.hasStaticShape())) { - return mlir::failure(); - } - - auto resShape = srcNDTyp.getShape(); - auto resTyp = resNDTyp.cloneWith(resShape, resNDTyp.getElementType()); - auto newOp = rewriter.create<::imex::ndarray::CastElemTypeOp>(op->getLoc(), - resTyp, src); - rewriter.replaceOpWithNewOp<::imex::ndarray::CastOp>(op, resNDTyp, newOp); - - return ::mlir::success(); - } -}; - -void imex::ndarray::CastElemTypeOp::getCanonicalizationPatterns( - mlir::RewritePatternSet &results, mlir::MLIRContext *context) { - results - .add( - context); -} diff --git a/lib/Dialect/NDArray/IR/CreateOp.cpp b/lib/Dialect/NDArray/IR/CreateOp.cpp index 13998b218..f1848e25c 100644 --- a/lib/Dialect/NDArray/IR/CreateOp.cpp +++ b/lib/Dialect/NDArray/IR/CreateOp.cpp @@ -54,8 +54,8 @@ class CreateOpConstantArgumentFolder final auto newOp = rewriter.create( loc, newReturnType, shape, createOp.getValue()); // cast to original type - rewriter.replaceOpWithNewOp(createOp, oldReturnType, - newOp); + rewriter.replaceOpWithNewOp<::mlir::tensor::CastOp>(createOp, oldReturnType, + newOp); return mlir::success(); } diff --git a/lib/Dialect/NDArray/IR/DimOp.cpp b/lib/Dialect/NDArray/IR/DimOp.cpp deleted file mode 100644 index c0555382e..000000000 --- a/lib/Dialect/NDArray/IR/DimOp.cpp +++ /dev/null @@ -1,169 +0,0 @@ -//===- DimOp.cpp - NDArray dialect --------------------------*- C++ -*-===// -// -// Copyright 2023 Intel Corporation -// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -/// -/// \file -/// This file implements the DimOp of the NDArray dialect. -/// Ported from NTensor. -/// -//===----------------------------------------------------------------------===// - -#include -#include -#include -#include -#include -#include -#include - -/// Materialize a single constant operation from a given attribute value with -/// the desired resultant type. -/// Ported from mlir::tensor dialect -mlir::Operation *imex::ndarray::NDArrayDialect::materializeConstant( - mlir::OpBuilder &builder, mlir::Attribute value, mlir::Type type, - mlir::Location loc) { - if (auto op = mlir::arith::ConstantOp::materialize(builder, value, type, loc)) - return op; - return nullptr; -} - -void imex::ndarray::DimOp::getAsmResultNames( - llvm::function_ref setNameFn) { - setNameFn(getResult(), "dim"); -} - -void imex::ndarray::DimOp::build(mlir::OpBuilder &builder, - mlir::OperationState &result, - mlir::Value source, int64_t index) { - auto loc = result.location; - auto indexValue = builder.create(loc, index); - build(builder, result, source, indexValue); -} - -std::optional imex::ndarray::DimOp::getConstantIndex() { - if (auto val = mlir::getConstantIntValue(getIndex())) - return *val; - - return {}; -} - -mlir::Speculation::Speculatability imex::ndarray::DimOp::getSpeculatability() { - auto constantIndex = getConstantIndex(); - if (!constantIndex) - return mlir::Speculation::NotSpeculatable; - - auto rankedType = - mlir::dyn_cast(getSource().getType()); - if (!rankedType) - return mlir::Speculation::NotSpeculatable; - - // The verifier rejects operations that violate this assertion. - assert(constantIndex < rankedType.getRank()); - return mlir::Speculation::Speculatable; -} - -/// Ported from mlir::tensor::DimOp -mlir::OpFoldResult imex::ndarray::DimOp::fold(FoldAdaptor adaptor) { - // All forms of folding require a known index. - auto index = llvm::dyn_cast_if_present(adaptor.getIndex()); - if (!index) - return {}; - - // Folding for unranked types is not supported. - auto ndarrayType = - llvm::dyn_cast(getSource().getType()); - if (!ndarrayType) - return {}; - - // Out of bound indices produce undefined behavior but are still valid IR. - // Don't choke on them. - int64_t indexVal = index.getInt(); - if (indexVal < 0 || indexVal >= ndarrayType.getRank()) - return {}; - - // Fold if the shape extent along the given index is known. - if (!ndarrayType.isDynamicDim(index.getInt())) { - mlir::Builder builder(getContext()); - return builder.getIndexAttr(ndarrayType.getShape()[index.getInt()]); - } - - mlir::Operation *definingOp = getSource().getDefiningOp(); - - // The size at the given index is now known to be a dynamic size. - unsigned unsignedIndex = index.getValue().getZExtValue(); - - if (auto sliceOp = - mlir::dyn_cast_or_null(definingOp)) { - // Fold only for non-rank reduced ops. For the rank-reduced version, rely on - // `resolve-shaped-type-result-dims` pass. - if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() && - sliceOp.isDynamicSize(unsignedIndex)) { - return {sliceOp.getDynamicSize(unsignedIndex)}; - } - } - - // dim(cast) -> dim - if (succeeded(imex::ndarray::foldArrayCast(*this))) - return getResult(); - - return {}; -} - -namespace { -/// Fold dim of a cast into the dim of the source of the ndarray cast. -/// Ported from mlir::tensor::DimOp -struct DimOfCastOp : public mlir::OpRewritePattern { - using mlir::OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult - matchAndRewrite(imex::ndarray::DimOp dimOp, - mlir::PatternRewriter &rewriter) const override { - auto castOp = dimOp.getSource().getDefiningOp(); - if (!castOp) - return mlir::failure(); - mlir::Value newSource = castOp.getOperand(); - rewriter.replaceOpWithNewOp(dimOp, newSource, - dimOp.getIndex()); - return mlir::success(); - } -}; - -// TODO: upstream -struct LinalgGenericDimPropagate - : public mlir::OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult - matchAndRewrite(mlir::tensor::DimOp op, - mlir::PatternRewriter &rewriter) const override { - auto src = op.getSource(); - auto generic = src.getDefiningOp(); - if (!generic) - return mlir::failure(); - - assert(generic.getOutputs().size() == generic.getResults().size()); - auto outIndex = [&]() -> size_t { - for (auto [i, out] : llvm::enumerate(generic.getResults())) { - if (out == src) - return i; - } - llvm_unreachable("Invalid result"); - }(); - - auto out = generic.getOutputs()[outIndex]; - - rewriter.replaceOpWithNewOp(op, out, op.getIndex()); - return mlir::success(); - } -}; -} // namespace - -void imex::ndarray::DimOp::getCanonicalizationPatterns( - ::mlir::RewritePatternSet &results, ::mlir::MLIRContext *context) { - results.insert(context); -} diff --git a/lib/Dialect/NDArray/IR/EWBinOp.cpp b/lib/Dialect/NDArray/IR/EWBinOp.cpp deleted file mode 100644 index 3317b2271..000000000 --- a/lib/Dialect/NDArray/IR/EWBinOp.cpp +++ /dev/null @@ -1,77 +0,0 @@ -//===- EWBinOp.cpp - NDArray dialect ---------------------------*- C++ -*-===// -// -// Copyright 2023 Intel Corporation -// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -/// -/// \file -/// This file implements the EWBinOp of the NDArray dialect. -/// -//===----------------------------------------------------------------------===// - -#include "EWOp.h" - -namespace { -/// Pattern to rewrite a ewbin op replacing dynamically shaped inputs -/// by statically shaped inputs if they are defined by an appropriate castop. -class EWBinOpInputCanonicalizer final - : public mlir::OpRewritePattern<::imex::ndarray::EWBinOp> { -public: - using mlir::OpRewritePattern<::imex::ndarray::EWBinOp>::OpRewritePattern; - - mlir::LogicalResult - matchAndRewrite(::imex::ndarray::EWBinOp op, - ::mlir::PatternRewriter &rewriter) const override { - - if (!llvm::isa<::imex::ndarray::NDArrayType>(op.getResult().getType())) { - return mlir::failure(); - }; - - bool succ = replaceOperandInplaceWithCast(rewriter, 0, op.getLhs(), op) || - replaceOperandInplaceWithCast(rewriter, 1, op.getRhs(), op); - return succ ? ::mlir::success() : ::mlir::failure(); - } -}; - -/// Pattern to rewrite a ewbin op replacing dynamically shaped result type -/// by statically shaped result type if both inputs are statically shaped. -class EWBinOpResultCanonicalizer final - : public mlir::OpRewritePattern<::imex::ndarray::EWBinOp> { -public: - using mlir::OpRewritePattern<::imex::ndarray::EWBinOp>::OpRewritePattern; - - mlir::LogicalResult - matchAndRewrite(::imex::ndarray::EWBinOp op, - ::mlir::PatternRewriter &rewriter) const override { - - auto lhs = op.getLhs(); - auto lPtTyp = mlir::dyn_cast<::imex::ndarray::NDArrayType>(lhs.getType()); - auto rhs = op.getRhs(); - auto rPtTyp = mlir::dyn_cast<::imex::ndarray::NDArrayType>(rhs.getType()); - auto resPtTyp = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(op.getResult().getType()); - if (!(lPtTyp && rPtTyp && resPtTyp && lPtTyp.hasStaticShape() && - rPtTyp.hasStaticShape() && !resPtTyp.hasStaticShape())) { - return mlir::failure(); - } - - auto outShape = ::imex::broadcast(lPtTyp.getShape(), rPtTyp.getShape()); - auto outTyp = resPtTyp.cloneWith(outShape, resPtTyp.getElementType()); - - auto nOp = rewriter.create<::imex::ndarray::EWBinOp>(op->getLoc(), outTyp, - op.getOp(), lhs, rhs); - rewriter.replaceOpWithNewOp<::imex::ndarray::CastOp>(op, resPtTyp, nOp); - - return ::mlir::success(); - } -}; - -} // namespace - -void imex::ndarray::EWBinOp::getCanonicalizationPatterns( - mlir::RewritePatternSet &results, mlir::MLIRContext *context) { - results.add(context); -} diff --git a/lib/Dialect/NDArray/IR/EWOp.h b/lib/Dialect/NDArray/IR/EWOp.h deleted file mode 100644 index 97e057128..000000000 --- a/lib/Dialect/NDArray/IR/EWOp.h +++ /dev/null @@ -1,36 +0,0 @@ -//===-- EWOp.h - NDArray pass details ---------------------------*- C++ -*-===// -// -// Copyright 2023 Intel Corporation -// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -/// -/// \file -/// This header file defines utilities for ew op canonicalization. -/// -//===----------------------------------------------------------------------===// - -#include -#include - -/// @brief replace operand of op with defining cast op -/// @return true if replacement succeeded, false otherwise -template -bool replaceOperandInplaceWithCast(::mlir::PatternRewriter &rewriter, - unsigned idx, ::mlir::Value arg, OpType op) { - auto ptTyp = mlir::dyn_cast<::imex::ndarray::NDArrayType>(arg.getType()); - if (ptTyp && !ptTyp.hasStaticShape()) { - auto defOp = arg.getDefiningOp<::imex::ndarray::CastOp>(); - if (defOp) { - auto src = defOp.getSource(); - auto srcPtTyp = mlir::cast<::imex::ndarray::NDArrayType>(src.getType()); - if (srcPtTyp.hasStaticShape()) { - rewriter.modifyOpInPlace(op, [&]() { op->setOperand(idx, src); }); - return true; - } - } - } - return false; -} diff --git a/lib/Dialect/NDArray/IR/EWUnyOp.cpp b/lib/Dialect/NDArray/IR/EWUnyOp.cpp deleted file mode 100644 index 444a30761..000000000 --- a/lib/Dialect/NDArray/IR/EWUnyOp.cpp +++ /dev/null @@ -1,75 +0,0 @@ -//===- EWUnyOp.cpp - NDArray dialect ---------------------------*- C++ -*-===// -// -// Copyright 2023 Intel Corporation -// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -/// -/// \file -/// This file implements the EWUnyOp of the NDArray dialect. -/// -//===----------------------------------------------------------------------===// - -#include "EWOp.h" - -namespace { -/// Pattern to rewrite a EWUnyOp replacing dynamically shaped inputs -/// by statically shaped inputs if they are defined by an appropriate castop. -class EWUnyOpInputCanonicalizer final - : public mlir::OpRewritePattern<::imex::ndarray::EWUnyOp> { -public: - using mlir::OpRewritePattern<::imex::ndarray::EWUnyOp>::OpRewritePattern; - - mlir::LogicalResult - matchAndRewrite(::imex::ndarray::EWUnyOp op, - ::mlir::PatternRewriter &rewriter) const override { - - if (!llvm::isa<::imex::ndarray::NDArrayType>(op.getResult().getType())) { - return mlir::failure(); - }; - - return replaceOperandInplaceWithCast(rewriter, 0, op.getSrc(), op) - ? ::mlir::success() - : ::mlir::failure(); - } -}; - -/// Pattern to rewrite a EWUnyOp replacing dynamically shaped result type -/// by statically shaped result type if input is statically shaped. -class EWUnyOpResultCanonicalizer final - : public mlir::OpRewritePattern<::imex::ndarray::EWUnyOp> { -public: - using mlir::OpRewritePattern<::imex::ndarray::EWUnyOp>::OpRewritePattern; - - mlir::LogicalResult - matchAndRewrite(::imex::ndarray::EWUnyOp op, - ::mlir::PatternRewriter &rewriter) const override { - - auto src = op.getSrc(); - auto srcPtTyp = mlir::dyn_cast<::imex::ndarray::NDArrayType>(src.getType()); - auto resPtTyp = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(op.getResult().getType()); - if (!(srcPtTyp && resPtTyp && srcPtTyp.hasStaticShape() && - !resPtTyp.hasStaticShape())) { - return mlir::failure(); - } - - auto outShape = srcPtTyp.getShape(); - auto outTyp = resPtTyp.cloneWith(outShape, resPtTyp.getElementType()); - - auto nOp = rewriter.create<::imex::ndarray::EWUnyOp>(op->getLoc(), outTyp, - op.getOp(), src); - rewriter.replaceOpWithNewOp<::imex::ndarray::CastOp>(op, resPtTyp, nOp); - - return ::mlir::success(); - } -}; - -} // namespace - -void imex::ndarray::EWUnyOp::getCanonicalizationPatterns( - mlir::RewritePatternSet &results, mlir::MLIRContext *context) { - results.add(context); -} diff --git a/lib/Dialect/NDArray/IR/InsertSliceOp.cpp b/lib/Dialect/NDArray/IR/InsertSliceOp.cpp index b6de2d89f..c95e13934 100644 --- a/lib/Dialect/NDArray/IR/InsertSliceOp.cpp +++ b/lib/Dialect/NDArray/IR/InsertSliceOp.cpp @@ -152,7 +152,7 @@ class InsertSliceOpZeroFolder final #if 0 // FIXME auto srcTyp = ::mlir::dyn_cast( insertSliceOp.getSource().getType()); - if (srcTyp && srcTyp.hasZeroSize()) { + if (srcTyp && hasZeroSize(srcTyp.getShape())) { if (insertSliceOp->getNumResults() == 0) { rewriter.eraseOp(insertSliceOp); } else { @@ -230,8 +230,8 @@ struct InsertSliceOpCastFolder final return mlir::failure(); auto getSourceOfCastOp = [](mlir::Value v) -> std::optional { - auto castOp = v.getDefiningOp(); - if (!castOp || !imex::ndarray::canFoldIntoConsumerOp(castOp)) + auto castOp = v.getDefiningOp(); + if (!castOp || !mlir::tensor::canFoldIntoConsumerOp(castOp)) return std::nullopt; return castOp.getSource(); }; @@ -253,7 +253,7 @@ struct InsertSliceOpCastFolder final if (hasReturnValue && (dst.getType() != insertSliceOp.getDestinationType())) { - replacement = rewriter.create( + replacement = rewriter.create( insertSliceOp.getLoc(), insertSliceOp.getDestinationType(), replacement->getResult(0)); } diff --git a/lib/Dialect/NDArray/IR/LinSpaceOp.cpp b/lib/Dialect/NDArray/IR/LinSpaceOp.cpp index b252b2053..4aec97a66 100644 --- a/lib/Dialect/NDArray/IR/LinSpaceOp.cpp +++ b/lib/Dialect/NDArray/IR/LinSpaceOp.cpp @@ -46,8 +46,8 @@ class LinSpaceOpConstantArgumentFolder final auto newOp = rewriter.create( loc, newReturnType, op.getStart(), op.getStop(), num, op.getEndpoint()); // cast to original type - rewriter.replaceOpWithNewOp(op, oldReturnType, - newOp); + rewriter.replaceOpWithNewOp<::mlir::tensor::CastOp>(op, oldReturnType, + newOp); return mlir::success(); } diff --git a/lib/Dialect/NDArray/IR/NDArrayOps.cpp b/lib/Dialect/NDArray/IR/NDArrayOps.cpp index 3a299f483..316c14ca8 100644 --- a/lib/Dialect/NDArray/IR/NDArrayOps.cpp +++ b/lib/Dialect/NDArray/IR/NDArrayOps.cpp @@ -27,6 +27,10 @@ void NDArrayDialect::initialize() { #define GET_TYPEDEF_LIST #include >(); + addAttributes< +#define GET_ATTRDEF_LIST +#include + >(); addOperations< #define GET_OP_LIST #include @@ -36,97 +40,6 @@ void NDArrayDialect::initialize() { } // namespace ndarray } // namespace imex -namespace imex { -namespace ndarray { - -NDArrayType NDArrayType::get(::mlir::MLIRContext *context, - ::llvm::ArrayRef shape, - ::mlir::Type elementType, - ::llvm::ArrayRef<::mlir::Attribute> environments, - ::mlir::StringAttr layout) { - ::mlir::SmallVector<::mlir::Attribute> envs(environments); - struct { - bool operator()(::mlir::Attribute a_, ::mlir::Attribute b_) const { - return ::mlir::hash_value(a_) < ::mlir::hash_value(b_); - } - } attrSort; - std::sort(envs.begin(), envs.end(), attrSort); - return Base::get(context, std::move(shape), std::move(elementType), - std::move(envs), std::move(layout)); -} - -NDArrayType NDArrayType::get(::llvm::ArrayRef shape, - ::mlir::Type elementType, - ::mlir::ArrayRef<::mlir::Attribute> environments, - std::optional<::llvm::StringRef> layout) { - auto ctx = elementType.getContext(); - auto l = - layout ? ::mlir::StringAttr::get(ctx, *layout) : ::mlir::StringAttr{}; - return get(ctx, shape, elementType, environments, l); -} - -NDArrayType NDArrayType::get(::llvm::ArrayRef shape, - ::mlir::Type elementType, - ::mlir::ArrayRef<::mlir::Attribute> environments, - ::mlir::StringAttr layout) { - auto ctx = elementType.getContext(); - return get(ctx, shape, elementType, environments, layout); -} - -::mlir::MemRefType NDArrayType::getMemRefType(::mlir::Value val) const { - if (val) { - auto defOp = val.getDefiningOp<::mlir::bufferization::ToTensorOp>(); - if (defOp) { - auto ty = - defOp.getMemref().getType().cloneWith(getShape(), getElementType()); - return mlir::cast<::mlir::MemRefType>(ty); - } - } - return ::imex::getMemRefType(getContext(), getShape(), getElementType()); -} - -::mlir::RankedTensorType NDArrayType::getTensorType() const { - return ::imex::getTensorType(getContext(), getShape(), getElementType()); -} - -NDArrayType NDArrayType::cloneWithDynDims() const { - if (hasZeroSize() || hasUnitSize()) { - return mlir::cast(cloneWith(getShape(), getElementType())); - } - return NDArrayType::get( - ::mlir::SmallVector(getRank(), ::mlir::ShapedType::kDynamic), - getElementType(), getEnvironments(), getLayout()); -} - -} // namespace ndarray -} // namespace imex - -bool imex::ndarray::NDArrayBase::hasRank() const { return true; } - -llvm::ArrayRef imex::ndarray::NDArrayBase::getShape() const { - return mlir::cast(*this).getShape(); -} - -imex::ndarray::NDArrayBase imex::ndarray::NDArrayBase::cloneWith( - std::optional> shape, Type elementType) const { - auto t = mlir::cast(*this); - return NDArrayType::get(shape.value_or(getShape()), elementType, - t.getEnvironments(), t.getLayout()); -} - -imex::ndarray::NDArrayBase -imex::ndarray::NDArrayBase::cloneWithEnv(::mlir::Attribute env) const { - auto t = mlir::cast(*this); - ::mlir::SmallVector<::mlir::Attribute> envs(t.getEnvironments()); - envs.emplace_back(env); - return NDArrayType::get(t.getShape(), t.getElementType(), envs, - t.getLayout()); -} - -bool imex::ndarray::NDArrayBase::isValidElementType(Type type) { - return type.isIntOrIndexOrFloat(); -} - bool imex::ndarray::isUnitShape(const llvm::ArrayRef shp) { for (auto d : shp) { if (d != 1) @@ -135,12 +48,8 @@ bool imex::ndarray::isUnitShape(const llvm::ArrayRef shp) { return true; } -bool imex::ndarray::NDArrayType::hasUnitSize() const { - return isUnitShape(getShape()); -} - -bool imex::ndarray::NDArrayType::hasZeroSize() const { - for (auto d : getShape()) { +bool imex::ndarray::hasZeroSize(const llvm::ArrayRef shp) { + for (auto d : shp) { if (d == 0) return true; } @@ -150,5 +59,7 @@ bool imex::ndarray::NDArrayType::hasZeroSize() const { #include #define GET_TYPEDEF_CLASSES #include +#define GET_ATTRDEF_CLASSES +#include #define GET_OP_CLASSES #include diff --git a/lib/Dialect/NDArray/IR/PermuteDimsOp.cpp b/lib/Dialect/NDArray/IR/PermuteDimsOp.cpp deleted file mode 100644 index dd44a1872..000000000 --- a/lib/Dialect/NDArray/IR/PermuteDimsOp.cpp +++ /dev/null @@ -1,127 +0,0 @@ -//===- PermuteDimsOp.cpp - NDArray dialect ---------------------*- C++ -*-===// -// -// Copyright 2024 Intel Corporation -// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -/// -/// \file -/// This file implements the PermuteDimsOp of the NDArray dialect. -/// Copied from NTensor. -/// -//===----------------------------------------------------------------------===// - -#include "mlir/IR/Value.h" -#include "mlir/Support/LogicalResult.h" -#include "llvm/ADT/SmallVector.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace { - -bool isPermutation(const ::llvm::ArrayRef &axes) { - auto sortedAxes = ::mlir::SmallVector(axes.begin(), axes.end()); - std::sort(sortedAxes.begin(), sortedAxes.end()); - for (int64_t i = 0; static_cast(i) < axes.size(); ++i) { - if (sortedAxes[i] != i) - return false; - } - return true; -} - -bool isSorted(const ::llvm::ArrayRef &axes) { - for (int64_t i = 0; static_cast(i) < axes.size(); ++i) { - if (axes[i] != i) - return false; - } - return true; -} - -/// Pattern to rewrite a permute_dims op with constant arguments. -/// Propagates constant shape args to op return type. -class PermuteDimsOpConstantArgumentFolder final - : public mlir::OpRewritePattern { -public: - using mlir::OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult - matchAndRewrite(imex::ndarray::PermuteDimsOp op, - mlir::PatternRewriter &rewriter) const override { - auto source = op.getSource(); - const auto axes = op.getAxes(); - auto rank = static_cast(axes.size()); - - auto sourceType = source.getType(); - if (rank != sourceType.getRank()) - return ::mlir::failure(); - - if (isSorted(axes)) { - rewriter.replaceOpWithNewOp(op, op.getType(), - source); - return ::mlir::success(); - } - - auto oldReturnType = op.getType(); - if (oldReturnType.hasStaticShape()) { - return ::mlir::failure(); - } - - const auto &oldShape = sourceType.getShape(); - ::mlir::SmallVector newShape(rank); - for (int64_t i = 0; i < rank; ++i) { - newShape[i] = oldShape[axes[i]]; - } - - auto newReturnType = - sourceType.cloneWith(newShape, sourceType.getElementType()); - if (newReturnType == oldReturnType) - return ::mlir::failure(); - - auto newOp = rewriter.create( - op->getLoc(), newReturnType, source, axes); - rewriter.replaceOpWithNewOp(op, oldReturnType, - newOp); - - return mlir::success(); - } -}; -} // namespace - -void imex::ndarray::PermuteDimsOp::getCanonicalizationPatterns( - mlir::RewritePatternSet &results, mlir::MLIRContext *context) { - // TODO: - // - convert [0,1,2] & [2,1,0] to no-op - results.add(context); -} - -mlir::LogicalResult imex::ndarray::PermuteDimsOp::verify() { - const auto axes = getAxes(); - - if (!isPermutation(axes)) { - return mlir::failure(); - } - - auto sourceType = getSource().getType(); - auto returnType = getType(); - - if (sourceType.hasStaticShape() && returnType.hasStaticShape()) { - auto sourceShape = sourceType.getShape(); - auto returnShape = returnType.getShape(); - for (size_t i = 0; i < axes.size(); ++i) { - if (sourceShape[axes[i]] != returnShape[i]) { - return mlir::failure(); - } - } - } - - return mlir::success(); -} diff --git a/lib/Dialect/NDArray/IR/SubviewOp.cpp b/lib/Dialect/NDArray/IR/SubviewOp.cpp index 05e2af816..cd060a06c 100644 --- a/lib/Dialect/NDArray/IR/SubviewOp.cpp +++ b/lib/Dialect/NDArray/IR/SubviewOp.cpp @@ -22,8 +22,8 @@ #include mlir::RankedTensorType imex::ndarray::SubviewOp::inferResultType( - mlir::RankedTensorType sourceType, - mlir::ArrayRef staticOffsets, mlir::ArrayRef staticSizes, + mlir::RankedTensorType sourceType, mlir::ArrayRef staticOffsets, + mlir::ArrayRef staticSizes, mlir::ArrayRef staticStrides) { unsigned rank = sourceType.getRank(); (void)rank; @@ -120,8 +120,8 @@ void imex::ndarray::SubviewOp::build( mlir::ArrayRef sizes, mlir::ArrayRef strides, mlir::ArrayRef attrs) { - build(b, result, mlir::RankedTensorType(), source, offsets, sizes, - strides, attrs); + build(b, result, mlir::RankedTensorType(), source, offsets, sizes, strides, + attrs); } // Build a SubViewOp with static entries and inferred result type. @@ -193,8 +193,8 @@ void imex::ndarray::SubviewOp::build( mlir::OpBuilder &b, mlir::OperationState &result, mlir::Value source, mlir::ValueRange offsets, mlir::ValueRange sizes, mlir::ValueRange strides, mlir::ArrayRef attrs) { - build(b, result, mlir::RankedTensorType(), source, offsets, sizes, - strides, attrs); + build(b, result, mlir::RankedTensorType(), source, offsets, sizes, strides, + attrs); } // Build a ExtractSliceOp with mixed static and dynamic entries and custom @@ -248,8 +248,8 @@ void imex::ndarray::ExtractSliceOp::build( mlir::OpBuilder &b, mlir::OperationState &result, mlir::Value source, mlir::ValueRange offsets, mlir::ValueRange sizes, mlir::ValueRange strides, mlir::ArrayRef attrs) { - build(b, result, mlir::RankedTensorType(), source, offsets, sizes, - strides, attrs); + build(b, result, mlir::RankedTensorType(), source, offsets, sizes, strides, + attrs); } // Copypasted from upstream tensor. @@ -455,8 +455,7 @@ class ExtractSliceFolder final return mlir::failure(); } - size_t rank = - mlir::cast<::mlir::RankedTensorType>(src.getType()).getRank(); + size_t rank = mlir::cast<::mlir::RankedTensorType>(src.getType()).getRank(); auto myOffs = op.getStaticOffsets(); auto mySizes = op.getStaticSizes(); auto myStrides = op.getStaticStrides(); @@ -502,7 +501,7 @@ class SubviewCastFolder final : public mlir::OpRewritePattern { if (!castOp) return mlir::failure(); - if (!imex::ndarray::canFoldIntoConsumerOp(castOp)) + if (!mlir::tensor::canFoldIntoConsumerOp(castOp)) return mlir::failure(); // Create folded extract. @@ -512,8 +511,8 @@ class SubviewCastFolder final : public mlir::OpRewritePattern { sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(), sliceOp.getStaticStrides()); if (newResult.getType() != sliceOp.getType()) - newResult = rewriter.create<::mlir::tensor::CastOp>(loc, sliceOp.getType(), - newResult); + newResult = rewriter.create<::mlir::tensor::CastOp>( + loc, sliceOp.getType(), newResult); rewriter.replaceOp(sliceOp, newResult); return mlir::success(); } diff --git a/lib/Dialect/NDArray/Transforms/AddGPURegions.cpp b/lib/Dialect/NDArray/Transforms/AddGPURegions.cpp index f03c240e1..9ef914bdb 100644 --- a/lib/Dialect/NDArray/Transforms/AddGPURegions.cpp +++ b/lib/Dialect/NDArray/Transforms/AddGPURegions.cpp @@ -12,7 +12,6 @@ /// //===----------------------------------------------------------------------===// -#include #include #include #include @@ -26,9 +25,7 @@ namespace imex { #define GEN_PASS_DEF_ADDGPUREGIONS #include -} // namespace imex -namespace imex { namespace { // Base-class for RewriterPatterns which handle recursion @@ -113,36 +110,17 @@ struct AddGPURegionsPass ::mlir::FrozenRewritePatternSet patterns; // It would be nicer to have a single rewrite-pattern which covers all // NDArrayOps - insertPatterns, - NDArrayOpRWP<::imex::ndarray::FromMemRefOp>, - NDArrayOpRWP<::imex::ndarray::DeleteOp>, - NDArrayOpRWP<::imex::ndarray::DimOp>, + insertPatterns, NDArrayOpRWP<::imex::ndarray::SubviewOp>, NDArrayOpRWP<::imex::ndarray::ExtractSliceOp>, NDArrayOpRWP<::imex::ndarray::InsertSliceOp>, NDArrayOpRWP<::imex::ndarray::ImmutableInsertSliceOp>, - NDArrayOpRWP<::imex::ndarray::LoadOp>, - NDArrayOpRWP<::imex::ndarray::CopyOp, false>, - NDArrayOpRWP<::imex::ndarray::CastOp>, + NDArrayOpRWP<::mlir::tensor::CastOp>, NDArrayOpRWP<::imex::ndarray::CastElemTypeOp>, NDArrayOpRWP<::imex::ndarray::LinSpaceOp>, NDArrayOpRWP<::imex::ndarray::CreateOp>, - NDArrayOpRWP<::imex::ndarray::ReshapeOp>, - NDArrayOpRWP<::imex::ndarray::EWBinOp>, - NDArrayOpRWP<::imex::ndarray::EWUnyOp>, - NDArrayOpRWP<::imex::ndarray::ReductionOp>, - NDArrayOpRWP<::imex::ndarray::PermuteDimsOp>, - NDArrayOpRWP<::imex::dist::InitDistArrayOp>, - NDArrayOpRWP<::imex::dist::LocalOffsetsOfOp>, - NDArrayOpRWP<::imex::dist::PartsOfOp>, - NDArrayOpRWP<::imex::dist::DefaultPartitionOp>, - NDArrayOpRWP<::imex::dist::LocalTargetOfSliceOp>, - NDArrayOpRWP<::imex::dist::LocalBoundingBoxOp>, - NDArrayOpRWP<::imex::dist::LocalCoreOp>, - NDArrayOpRWP<::imex::dist::RePartitionOp>, - NDArrayOpRWP<::imex::dist::SubviewOp>, - NDArrayOpRWP<::imex::dist::EWBinOp>, - NDArrayOpRWP<::imex::dist::EWUnyOp>>(getContext(), patterns); + NDArrayOpRWP<::imex::ndarray::ReshapeOp>>(getContext(), + patterns); (void)::mlir::applyPatternsAndFoldGreedily(this->getOperation(), patterns); } }; diff --git a/lib/Dialect/NDArray/Transforms/CMakeLists.txt b/lib/Dialect/NDArray/Transforms/CMakeLists.txt index ba1d8d026..928ebc7f8 100644 --- a/lib/Dialect/NDArray/Transforms/CMakeLists.txt +++ b/lib/Dialect/NDArray/Transforms/CMakeLists.txt @@ -1,6 +1,6 @@ add_imex_dialect_library(IMEXNDArrayTransforms - NDArrayDist.cpp AddGPURegions.cpp + CoalesceShardOps.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/imex/Dialect/NDArray @@ -12,5 +12,4 @@ add_imex_dialect_library(IMEXNDArrayTransforms MLIRIR MLIRPass IMEXNDArrayDialect - IMEXDistDialect ) diff --git a/lib/Dialect/Dist/Transforms/DistCoalesce.cpp b/lib/Dialect/NDArray/Transforms/CoalesceShardOps.cpp similarity index 62% rename from lib/Dialect/Dist/Transforms/DistCoalesce.cpp rename to lib/Dialect/NDArray/Transforms/CoalesceShardOps.cpp index fd3e94084..05e1d0d9a 100644 --- a/lib/Dialect/Dist/Transforms/DistCoalesce.cpp +++ b/lib/Dialect/NDArray/Transforms/CoalesceShardOps.cpp @@ -1,4 +1,4 @@ -//===- DistCoalesce.cpp - NDArrayToDist Transform -----*- C++ -*-===// +//===- CoalesceShardOps.cpp - CoalesceShardingOps Transform -----*- C++ -*-===// // // Copyright 2023 Intel Corporation // Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. @@ -8,12 +8,12 @@ //===----------------------------------------------------------------------===// /// /// \file -/// This file implements transforms of the Dist dialect. +/// This file implements a transform of Mesh and NDArray dialects. /// /// This pass tries to minimize the number of mesh::ShardOps. /// Instead of creating a new copy for each repartition, it tries to combine /// multiple RePartitionOps into one. For this, it computes the local bounding -/// box of several uses of repartitioned copies of the same base araay. It +/// box of several uses of repartitioned copies of the same base array. It /// replaces all matched RepartitionOps with one which provides the computed /// bounding box. Uses of the eliminated RePartitionOps get updated with th /// appropriate target part as originally used. Right now supported uses are @@ -33,9 +33,6 @@ /// //===----------------------------------------------------------------------===// -#include -#include -#include #include #include #include @@ -48,18 +45,16 @@ #include #include #include +#include #include #include #include namespace imex { -#define GEN_PASS_DEF_DISTCOALESCE -#include "imex/Dialect/Dist/Transforms/Passes.h.inc" -} // namespace imex +#define GEN_PASS_DEF_COALESCESHARDOPS +#include "imex/Dialect/NDArray/Transforms/Passes.h.inc" -namespace imex { -namespace dist { namespace { @@ -102,193 +97,10 @@ bool isElementwise(::mlir::Operation *op) { // ***** Pass infrastructure ***** // ******************************* -// Lowering dist dialect by no-ops -struct DistCoalescePass - : public imex::impl::DistCoalesceBase { - - DistCoalescePass() = default; - -#if 0 - // returns true if a Value is defined by any of the given operation types - template - static ::mlir::Operation *isDefByAnyOf(const ::mlir::Value &val) { - if (auto res = val.getDefiningOp()) - return res; - if constexpr (sizeof...(Ts)) - return isDefByAnyOf(val); - else if constexpr (!sizeof...(Ts)) - return nullptr; - } - - // returns true if an operation is of any of the given types - template - static bool isAnyOf(const ::mlir::Operation *op) { - if (::mlir::dyn_cast(op)) - return true; - if constexpr (sizeof...(Ts)) - return isAnyOf(op); - else if constexpr (!sizeof...(Ts)) - return false; - } - - static bool isCreator(::mlir::Operation *op) { - return op && - isAnyOf<::imex::ndarray::LinSpaceOp, ::imex::ndarray::CreateOp>(op); - } - - /// return true if given op comes from a EWOp and has another EWOp - /// as its single user. - bool is_temp(::mlir::mesh::ShardOp &op) { - if (!op->hasAttr("target") && op->hasOneUse() && - ::mlir::isa<::imex::dist::EWBinOp, ::imex::dist::EWUnyOp>( - *op->user_begin()) && - ::mlir::isa<::imex::dist::EWBinOp, ::imex::dist::EWUnyOp>( - op.getSrc().getDefiningOp())) { - return true; - } - return false; - } - - /// update a SubviewOp with a target part - /// create and return a new op if the SubviewOp has more than one use. - ::mlir::Operation *updateTargetPart(::mlir::IRRewriter &builder, - ::imex::dist::SubviewOp op, - const ::mlir::ValueRange &tOffs, - const ::mlir::ValueRange &tSizes) { - - // check if an existing target is the same as ours - auto offs = op.getTargetOffsets(); - auto szs = op.getTargetSizes(); - if (offs.size() > 0) { - assert(offs.size() == szs.size()); - ::mlir::SmallVector<::mlir::Operation *> toBeMoved; - for (size_t i = 0; i < offs.size(); ++i) { - if ((tOffs[i] != offs[i] || tSizes[i] != szs[i]) && !op->hasOneUse()) { - // existing but different target -> need a new repartition for our - // back-propagation - auto val = op.getSource(); - builder.setInsertionPointAfter(op); - - auto tmp = tOffs[0].getDefiningOp(); - auto &dom = this->getAnalysis<::mlir::DominanceInfo>(); - if (!dom.dominates(tmp, op)) { - toBeMoved.resize(0); - if (canMoveAfter(dom, tmp, op, toBeMoved)) { - ::mlir::Operation *curr = op; - for (auto dop : toBeMoved) { - dop->moveAfter(curr); - curr = dop; - } - builder.setInsertionPointAfter(curr); - } else { - assert(false && "Not implemented"); - } - } - assert(tOffs.size() == tSizes.size()); - auto dynPtType = cloneWithDynEnv( - mlir::cast<::imex::ndarray::NDArrayType>(val.getType())); - return builder.create<::imex::mesh::ShardOp>( - op->getLoc(), dynPtType, val, tOffs, tSizes); - } - } - // if same existing target -> nothing to be done - } else { - const int32_t rank = static_cast(tOffs.size()); - const int32_t svRank = op.getStaticSizes().size(); - const bool hasUnitSize = - mlir::cast<::imex::ndarray::NDArrayType>(op.getResult().getType()) - .hasUnitSize(); - - if (svRank == rank || hasUnitSize) { - if (hasUnitSize) { - // Here the subview can have a different rank than the target. - // The target can be empty (all dims have size zero) for example when - // the source insert_slice is unit-sized and happens on a different - // prank. In such cases we need to have all zeros in our target (of - // rank svRank). Otherwise the target size is 1. - mlir::OpBuilder::InsertionGuard guard(builder); - if (rank) { - builder.setInsertionPointAfter(tSizes[0].getDefiningOp()); - } else { - builder.setInsertionPoint(op); - } - - // first compute total size of target - auto loc = op->getLoc(); - auto zero = easyIdx(loc, builder, 0); - auto one = easyIdx(loc, builder, 1); - auto sz = one; - for (auto r = 0; r < rank; ++r) { - sz = sz * easyIdx(loc, builder, tSizes[r]); - } - // check if the target has total size 0 - sz = sz.eq(zero).select(zero, one); - op->insertOperands(op->getNumOperands(), - ::imex::ValVec(svRank, zero.get())); - op->insertOperands(op->getNumOperands(), - ::imex::ValVec(svRank, sz.get())); - } else { - // no existing target -> use ours - op->insertOperands(op->getNumOperands(), tOffs); - op->insertOperands(op->getNumOperands(), tSizes); - } - - const auto sSzsName = op.getOperandSegmentSizesAttrName(); - const auto oa = op->getAttrOfType<::mlir::DenseI32ArrayAttr>(sSzsName); - ::std::array sSzs{oa[0], oa[1], oa[2], - oa[3], svRank, svRank}; - op->setAttr(sSzsName, builder.getDenseI32ArrayAttr(sSzs)); - } else { - assert(false && "found dependent operation with different rank, needs " - "broadcasting support?"); - } - } - return nullptr; - } - - /// clone subviewops which are returned and mark them "final" - /// Needed to protect them from being "redirected" to a reparitioned copy - void backPropagateReturn(::mlir::IRRewriter &builder, - ::mlir::func::ReturnOp retOp) { - mlir::OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPoint(retOp); - bool altered = false; - ::imex::ValVec oprnds; - ::mlir::SmallVector<::mlir::Operation *> toErase; - for (auto val : retOp->getOperands()) { - if (isDist(val)) { - bool oneUse = true; - // "skip" casts and observe if this is a single-use chain - auto castOp = val.getDefiningOp<::mlir::UnrealizedConversionCastOp>(); - while (castOp && castOp.getInputs().size() == 1) { - if (!castOp->hasOneUse()) { - oneUse = false; - } - val = castOp.getInputs().front(); - castOp = val.getDefiningOp<::mlir::UnrealizedConversionCastOp>(); - } +struct CoalesceShardOpsPass + : public imex::impl::CoalesceShardOpsBase { - if (auto typedOp = val.getDefiningOp<::imex::dist::SubviewOp>()) { - auto iOp = builder.clone(*typedOp); - iOp->setAttr("final", builder.getUnitAttr()); - if (oneUse && typedOp->hasOneUse()) { - toErase.emplace_back(typedOp); - } - oprnds.emplace_back(iOp->getResult(0)); - altered = true; - continue; - } - } - oprnds.emplace_back(val); - } - if (altered) { - retOp->setOperands(oprnds); - for (auto op : toErase) { - op->erase(); - } - } - } -#endif + CoalesceShardOpsPass() = default; /// Follow def-chain of given Value until hitting a creation function /// or array-returning EWBinOp or EWUnyOp et al @@ -321,7 +133,6 @@ struct DistCoalescePass return nullptr; } - /// The actual back propagation of target parts /// if meeting a supported op, recursively gets defining ops and back @@ -417,28 +228,87 @@ struct DistCoalescePass assert(op->hasOneUse()); return ::mlir::dyn_cast<::mlir::mesh::ShardOp>(op); } -#if 0 - - /// compute target part for a given InsertSliceOp - ::imex::dist::TargetOfSliceOp computeTarget(::mlir::IRRewriter &builder, - ::imex::ndarray::InsertSliceOp op, - ::mlir::Value sharding) { - auto shardingOp = - ::mlir::cast<::mlir::mesh::ShardingOp>(sharding.getDefiningOp()); - auto sOffs = op.getStaticOffsets(); - auto sSizes = op.getStaticSizes(); - auto sStrides = op.getStaticStrides(); - assert(!(::mlir::ShapedType::isDynamicShape(sSizes) || - ::mlir::ShapedType::isDynamicShape(sOffs) || - ::mlir::ShapedType::isDynamicShape(sStrides)) || - (false && "SubviewOp must have dynamic offsets, sizes and strides")); - - auto src = getShardOpOfOperand(op.getDestination()).getSrc(); - return builder.create<::imex::dist::TargetOfSliceOp>( - op->getLoc(), src, sOffs, sSizes, sStrides, shardingOp.getMeshAttr(), - shardingOp.getSplitAxes()); + + template + static auto getBaseShardDimSize(T shard, T numShards, T extend) { + return extend / numShards + shard.sge(numShards - (extend % numShards)).select(1l, 0l); + }; + + static auto getBaseShardDimSize(int64_t shard, int64_t numShards, int64_t extend) { + return extend / numShards + (shard >= numShards - (extend % numShards) ? 1 : 0); + }; + + static ::mlir::SmallVector<::imex::EasyI64> extendHaloForSliceOp( + ::mlir::IRRewriter &rewriter, + mlir::Operation * op, + ::mlir::ArrayRef baseShape, + ::mlir::FlatSymbolRefAttr mesh, + ::mlir::mesh::MeshAxesArrayAttr splitAxes, + const ::mlir::SmallVector<::imex::EasyI64> &dynHaloSizes, + ::mlir::ArrayRef staticOffsets, + ::mlir::ArrayRef staticSizes, + ::mlir::ArrayRef staticStrides, + ::mlir::ArrayRef staticTargetOffsets) { + + const ::mlir::Location loc = op->getLoc(); + ::mlir::SymbolTableCollection symbolTable; + auto meshOp = ::mlir::mesh::getMesh(op, mesh, symbolTable); + assert(meshOp); + + // compute number of shards along split axes + // compute sharded dims extends (element count per sharded dim of base array) + ::mlir::SmallVector numShards, shardedDims; + for (auto dim = 0; dim<(int64_t)splitAxes.size(); ++dim) { + auto axes = splitAxes.getAxes()[dim]; + if(!axes.empty()) { + numShards.emplace_back(::mlir::mesh::collectiveProcessGroupSize(axes.asArrayRef(), meshOp)); + assert(!::mlir::ShapedType::isDynamic(numShards.back())); + shardedDims.emplace_back(dim); + } + } + + // init halo sizes either from input or to 0 + ::mlir::SmallVector<::imex::EasyI64> haloSizes = dynHaloSizes; + auto zero = easyI64(loc, rewriter, 0); + auto one = easyI64(loc, rewriter, 1); + if (haloSizes.empty()) { + haloSizes.resize(numShards.size()*2, zero); + } + assert(haloSizes.size() == numShards.size()*2); + + // iterate split axes and compute lower/upper halo bounds for each dim + int64_t curr = 0; + for (size_t dim=0; dim shardOps; - ::mlir::ValueRange halos; + ::mlir::SmallVector<::imex::EasyI64> halos; int numHalos = 0; for (auto axes : shardOp.getSharding().getDefiningOp<::mlir::mesh::ShardingOp>().getSplitAxes()) { if (!axes.empty()) { @@ -540,9 +410,8 @@ struct DistCoalescePass assert(svShardingOp); auto target = svShardingOp.getStaticShardedDimsOffsets(); assert(!::mlir::ShapedType::isDynamicShape(target) && "ShardOp of Subview must have static sharded dims sizes"); - auto mesh = svShardingOp.getMeshAttr().getValue(); builder.setInsertionPoint(shardOp); - halos = builder.create<::imex::dist::ExtendHaloForSliceOp>(subviewOp->getLoc(), haloResultTypes, baseShape.getShape(), mesh, svShardingOp.getSplitAxes(), halos, sOffs, sSizes, sStrides, target).getResult(); + halos = extendHaloForSliceOp(builder, subviewOp, baseShape.getShape(), svShardingOp.getMeshAttr(), svShardingOp.getSplitAxes(), halos, sOffs, sSizes, sStrides, target); shardOps.emplace_back(getShardOpOfOperand(subviewOp.getSource())); // subviewOps.emplace_back(subviewOp); } @@ -556,13 +425,17 @@ struct DistCoalescePass } // Update base sharding with halo sizes + ::imex::ValVec haloVals; + for (auto sz : halos) { + haloVals.emplace_back(sz.get()); + } auto orgSharding = shardOp.getSharding().getDefiningOp<::mlir::mesh::ShardingOp>(); builder.setInsertionPointAfter(shardOp); auto newSharding = builder.create<::mlir::mesh::ShardingOp>( shardOp->getLoc(), ::mlir::mesh::ShardingType::get(shardOp->getContext()), orgSharding.getMeshAttr(), orgSharding.getSplitAxesAttr(), orgSharding.getPartialAxesAttr(), orgSharding.getPartialTypeAttr(), ::mlir::DenseI64ArrayAttr::get(shardOp->getContext(), {}), ::mlir::ValueRange{}, - ::mlir::DenseI64ArrayAttr::get(shardOp->getContext(), ::mlir::SmallVector(halos.size(), ::mlir::ShapedType::kDynamic)), halos); + ::mlir::DenseI64ArrayAttr::get(shardOp->getContext(), ::mlir::SmallVector(haloVals.size(), ::mlir::ShapedType::kDynamic)), haloVals); auto newShardOp = builder.create<::mlir::mesh::ShardOp>( shardOp->getLoc(), shardOp, newSharding.getResult()); @@ -588,13 +461,11 @@ struct DistCoalescePass } }); } // runOnOperation -}; // DistCoalescePass +}; // CoalesceShardOpsPass } // namespace -} // namespace dist -/// Create a pass to eliminate Dist ops -std::unique_ptr<::mlir::Pass> createDistCoalescePass() { - return std::make_unique<::imex::dist::DistCoalescePass>(); +std::unique_ptr<::mlir::Pass> createCoalesceShardOpsPass() { + return std::make_unique<::imex::CoalesceShardOpsPass>(); } } // namespace imex diff --git a/lib/Dialect/NDArray/Transforms/NDArrayDist.cpp b/lib/Dialect/NDArray/Transforms/NDArrayDist.cpp deleted file mode 100644 index 4f7f893b4..000000000 --- a/lib/Dialect/NDArray/Transforms/NDArrayDist.cpp +++ /dev/null @@ -1,223 +0,0 @@ -//===- NDArrayDist.cpp - NDArrayToDist Transform ---------------*- C++ -*-===// -// -// Copyright 2023 Intel Corporation -// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -/// -/// \file -/// This file implements transform of the NDArray dialect to a combination of -/// NDArray and Dist dialects. -/// -/// Replace operations in NDArray if they have a shadow definition in Dist. -/// -//===----------------------------------------------------------------------===// - -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace imex { -#define GEN_PASS_DEF_NDARRAYDIST -#include -} // namespace imex - -namespace imex { -namespace dist { - -namespace { - -// ******************************* -// ***** Individual patterns ***** -// ******************************* - -// match given operation if array operands and results are distributed -// and shared the same team -template -struct DistOpRWP : public ::mlir::OpRewritePattern { - using ::mlir::OpRewritePattern::OpRewritePattern; - - ::mlir::LogicalResult match(FROM op) const override { - DistEnvAttr dEnv; - if (op->getNumResults() > 0) { - auto outDisTTyp = mlir::dyn_cast<::imex::ndarray::NDArrayType>( - op->getResultTypes().front()); - if (!outDisTTyp || !isDist(outDisTTyp)) { - return ::mlir::failure(); - } else if (outDisTTyp) { - // to verify same teams are used with operands - dEnv = getDistEnv(outDisTTyp); - } - } - - for (auto r : op->getOperands()) { - auto opType = mlir::dyn_cast<::imex::ndarray::NDArrayType>(r.getType()); - if (opType) { - auto dEnv2 = getDistEnv(opType); - if (!dEnv2) { - // all dist operands and the return type must use the same team - return ::mlir::failure(); - } - if (!dEnv) { - dEnv = dEnv2; - } else { - assert(dEnv2.getTeam() == dEnv.getTeam()); - } - } - } - - return ::mlir::success(); - } -}; - -struct DistSubviewOpRWP : public DistOpRWP<::imex::ndarray::SubviewOp> { - using DistOpRWP<::imex::ndarray::SubviewOp>::DistOpRWP; - void rewrite(::imex::ndarray::SubviewOp op, - ::mlir::PatternRewriter &rewriter) const override { - auto empty = ::mlir::ValueRange{}; - rewriter.replaceOpWithNewOp<::imex::dist::SubviewOp>( - op, op.getType(), op.getSource(), op.getOffsets(), op.getSizes(), - op.getStrides(), op.getStaticOffsets(), op.getStaticSizes(), - op.getStaticStrides(), empty, empty); - } -}; - -struct DistEWUnyOpRWP : public DistOpRWP<::imex::ndarray::EWUnyOp> { - using DistOpRWP<::imex::ndarray::EWUnyOp>::DistOpRWP; - void rewrite(::imex::ndarray::EWUnyOp op, - ::mlir::PatternRewriter &rewriter) const override { - auto empty = ::mlir::ValueRange{}; - rewriter.replaceOpWithNewOp<::imex::dist::EWUnyOp>( - op, op.getType(), op.getOp(), op.getSrc(), empty, empty, empty); - } -}; - -/// 1. Compute local slice of dst (target part) -/// 2. Repartition input to computed target part -struct DistInsertSliceOpRWP : public DistOpRWP<::imex::ndarray::InsertSliceOp> { - using DistOpRWP<::imex::ndarray::InsertSliceOp>::DistOpRWP; - ::mlir::LogicalResult - matchAndRewrite(::imex::ndarray::InsertSliceOp op, - ::mlir::PatternRewriter &rewriter) const override { - auto src = op.getSource(); - auto srcType = mlir::cast<::imex::ndarray::NDArrayType>(src.getType()); - if (srcType.getRank() == 0 || - src.getDefiningOp<::imex::dist::RePartitionOp>() || - ::mlir::failed(match(op))) { - return ::mlir::failure(); - } - - auto loc = op.getLoc(); - auto dst = op.getDestination(); - auto slcOffs = - getMixedAsValues(loc, rewriter, op.getOffsets(), op.getStaticOffsets()); - auto slcSizes = - getMixedAsValues(loc, rewriter, op.getSizes(), op.getStaticSizes()); - auto slcStrides = - getMixedAsValues(loc, rewriter, op.getStrides(), op.getStaticStrides()); - - auto tSlice = rewriter.create<::imex::dist::LocalTargetOfSliceOp>( - loc, dst, slcOffs, slcSizes, slcStrides); - ::mlir::ValueRange tSlcOffs = tSlice.getTOffsets(); - ::mlir::ValueRange tSlcSizes = tSlice.getTSizes(); - - // Repartition source - auto nSrc = createRePartition(loc, rewriter, src, tSlcOffs, tSlcSizes); - - rewriter.modifyOpInPlace(op, [&]() { op.getSourceMutable().set(nSrc); }); - return ::mlir::success(); - } -}; - -/// Rewrite ::imex::ndarray::EWBinOp to get a distributed ewbinop -/// if operands are distributed. -/// Repartitions input arrays as needed. -struct DistEWBinOpRWP : public DistOpRWP<::imex::ndarray::EWBinOp> { - using DistOpRWP<::imex::ndarray::EWBinOp>::DistOpRWP; - - void rewrite(::imex::ndarray::EWBinOp op, - ::mlir::PatternRewriter &rewriter) const override { - - // get inputs and types - auto loc = op.getLoc(); - auto lhs = op.getLhs(); - auto rhs = op.getRhs(); - auto lhsDistTyp = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(lhs.getType()); - auto rhsDistTyp = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(rhs.getType()); - auto outDistTyp = - mlir::dyn_cast<::imex::ndarray::NDArrayType>(op.getType()); - - // Repartition if necessary - // FIXME: this breaks with dim-sizes==1, even if statically known - auto rbLhs = - rhs == lhs || lhsDistTyp.getRank() == 0 - ? lhs - : createRePartition(loc, rewriter, lhs); //, tOffs, tSizes); - auto rbRhs = - rhs == lhs - ? rbLhs - : (rhsDistTyp.getRank() == 0 - ? rhs - : createRePartition(loc, rewriter, rhs)); //, tOffs, tSizes); - - auto empty = ::mlir::ValueRange{}; - rewriter.replaceOpWithNewOp<::imex::dist::EWBinOp>( - op, outDistTyp, op.getOp(), rbLhs, rbRhs, empty, empty, empty); - } -}; - -// ******************************* -// ***** Pass infrastructure ***** -// ******************************* - -// Lowering dist dialect by no-ops -struct NDArrayDistPass : public ::imex::impl::NDArrayDistBase { - - NDArrayDistPass() = default; - - void runOnOperation() override { - - ::mlir::FrozenRewritePatternSet patterns; - insertPatterns(getContext(), patterns); - (void)::mlir::applyPatternsAndFoldGreedily(this->getOperation(), patterns); - } -}; - -} // namespace -} // namespace dist - -/// Populate the given list with patterns that eliminate Dist ops -void populateNDArrayDistPatterns(::mlir::LLVMTypeConverter &converter, - ::mlir::RewritePatternSet &patterns) { - assert(false); -} - -/// Create a pass to eliminate Dist ops -std::unique_ptr<::mlir::Pass> createNDArrayDistPass() { - return std::make_unique<::imex::dist::NDArrayDistPass>(); -} - -} // namespace imex diff --git a/test/Conversion/DistToStandard/BoundingBox.mlir b/test/Conversion/DistToStandard/BoundingBox.mlir deleted file mode 100644 index a5a730d94..000000000 --- a/test/Conversion/DistToStandard/BoundingBox.mlir +++ /dev/null @@ -1,46 +0,0 @@ -// RUN: imex-opt --split-input-file --convert-dist-to-standard -canonicalize %s -verify-diagnostics -o -| FileCheck %s - -func.func @test_bb() -> (index, index, index, index, index, index, index, index, index, index, index, index, index, index) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c3 = arith.constant 3 : index - %c4 = arith.constant 4 : index - %c7 = arith.constant 7 : index - %c8 = arith.constant 8 : index - - %o1, %s1 = dist.local_bounding_box false[%c0] [%c8] [%c1] [%c4] [%c4] : index, index - %o2, %s2 = dist.local_bounding_box false[%c1] [%c7] [%c1] [%c4] [%c4] : index, index - %o3, %s3 = dist.local_bounding_box false[%c0] [%c8] [%c1] [%c4] [%c0] : index, index - %o4, %s4 = dist.local_bounding_box false[%c1] [%c7] [%c1] [%c0] [%c0] : index, index - %o5, %s5 = dist.local_bounding_box false[%c1] [%c7] [%c1] [%c2] [%c2] : index, index - %o6, %s6 = dist.local_bounding_box false[%c0] [%c0] [%c1] [%c4] [%c4] : index, index - %o7, %s7 = dist.local_bounding_box false[%c1] [%c3] [%c1] [%c0] [%c4] : index, index - - return %o1, %s1, %o2, %s2, %o3, %s3, %o4, %s4, %o5, %s5, %o6, %s6, %o7, %s7 : index, index, index, index, index, index, index, index, index, index, index, index, index, index -} -// CHECK-LABEL: func.func @test_bb -// CHECK: return %c4, %c4, %c5, %c3, %c4, %c0, %c1, %c0, %c3, %c2, %c4, %c0, %c1, %c3 - -// ----- -func.func @test_bb2() -> (index, index, index, index, index, index, index, index, index, index, index, index) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c3 = arith.constant 3 : index - %c4 = arith.constant 4 : index - %c5 = arith.constant 5 : index - %c7 = arith.constant 7 : index - %c8 = arith.constant 8 : index - - %o1, %s1 = dist.local_bounding_box false[%c1] [%c7] [%c1] [%c4] [%c4] bboffs %c4 bb_sizes %c4 : index, index - %o2, %s2 = dist.local_bounding_box false[%c0] [%c8] [%c1] [%c4] [%c0] bboffs %o1 bb_sizes %s1 : index, index - %o3, %s3 = dist.local_bounding_box false[%c1] [%c3] [%c1] [%c0] [%c4] bboffs %c0 bb_sizes %c5 : index, index - %o4, %s4 = dist.local_bounding_box false[%c1] [%c7] [%c1] [%c2] [%c2] bboffs %c1 bb_sizes %c0 : index, index - %o5, %s5 = dist.local_bounding_box false[%c0] [%c1] [%c1] [%c0] [%c1] bboffs %o4 bb_sizes %s4 : index, index - %o6, %s6 = dist.local_bounding_box false[%c7] [%c8] [%c1] [%c0] [%c1] bboffs %o5 bb_sizes %s5 : index, index - - return %o1, %s1, %o2, %s2, %o3, %s3, %o4, %s4, %o5, %s5, %o6, %s6 : index, index, index, index, index, index, index, index, index, index, index, index -} -// CHECK-LABEL: func.func @test_bb2 -// CHECK: return %c4, %c4, %c4, %c4, %c0, %c5, %c3, %c2, %c0, %c5, %c0, %c8 diff --git a/test/Conversion/DistToStandard/DefaultPartition.mlir b/test/Conversion/DistToStandard/DefaultPartition.mlir deleted file mode 100644 index 0c9f25448..000000000 --- a/test/Conversion/DistToStandard/DefaultPartition.mlir +++ /dev/null @@ -1,36 +0,0 @@ -// RUN: imex-opt --split-input-file --convert-dist-to-standard -canonicalize %s -verify-diagnostics -o -| FileCheck %s - -func.func @test_def_part() -> (index, index, index, index, index, index, index, index, index, index, index, index) { - %NP = arith.constant 8 : index - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c4 = arith.constant 4 : index - %c7 = arith.constant 7 : index - %c32 = arith.constant 32 : index - %c33 = arith.constant 33 : index - %c31 = arith.constant 31 : index - %o0, %s0 = "dist.default_partition"(%NP, %c7, %c31) : (index, index, index) -> (index, index) - %o1, %s1 = "dist.default_partition"(%NP, %c7, %c32) : (index, index, index) -> (index, index) - %o2, %s2 = "dist.default_partition"(%NP, %c7, %c33) : (index, index, index) -> (index, index) - %o3, %s3 = "dist.default_partition"(%NP, %c2, %c31) : (index, index, index) -> (index, index) - %o4, %s4 = "dist.default_partition"(%NP, %c1, %c31) : (index, index, index) -> (index, index) - %o5, %s5 = "dist.default_partition"(%NP, %c0, %c31) : (index, index, index) -> (index, index) - return %o0, %s0, %o1, %s1, %o2, %s2, %o3, %s3, %o4, %s4, %o5, %s5 : index, index, index, index, index, index, index, index, index, index, index, index -} -// CHECK-LABEL: func.func @test_def_part() -// CHECK: return %c27, %c4, %c28, %c4, %c28, %c5, %c7, %c4, %c3, %c4, %c0, %c3 - -// ----- -func.func @test_def_part2() -> (index, index, index, index, index, index, index, index) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c7 = arith.constant 7 : index - %c8 = arith.constant 8 : index - %o0:2, %s0:2 = "dist.default_partition"(%c2, %c1, %c8, %c8) : (index, index, index, index) -> (index, index, index, index) - %o1:2, %s1:2 = "dist.default_partition"(%c2, %c0, %c1, %c7) : (index, index, index, index) -> (index, index, index, index) - return %o0#0, %o0#1, %s0#0, %s0#1, %o1#0, %o1#1, %s1#0, %s1#1 : index, index, index, index, index, index, index, index -} -// CHECK-LABEL: func.func @test_def_part2() -// CHECK: return %c4, %c0, %c4, %c8, %c0, %c0, %c0, %c7 diff --git a/test/Conversion/DistToStandard/DistToStandard.mlir b/test/Conversion/DistToStandard/DistToStandard.mlir deleted file mode 100644 index 293254bbf..000000000 --- a/test/Conversion/DistToStandard/DistToStandard.mlir +++ /dev/null @@ -1,254 +0,0 @@ -// RUN: imex-opt --split-input-file --convert-dist-to-standard %s -verify-diagnostics -o -| FileCheck %s - -func.func @test_copy(%arg0: !ndarray.ndarray<2xi64>, %arg1: !ndarray.ndarray<6xi64>, %arg2: !ndarray.ndarray<0xi64>) -> (!ndarray.ndarray<2xi64>, !ndarray.ndarray<6xi64>, !ndarray.ndarray<0xi64>) { - %c1 = arith.constant 1 : index - %a = dist.init_dist_array l_offset %c1 parts %arg0, %arg1, %arg2 : index, !ndarray.ndarray<2xi64>, !ndarray.ndarray<6xi64>, !ndarray.ndarray<0xi64> to !ndarray.ndarray<33xi64, #dist.dist_env> - %1 = ndarray.copy %a : !ndarray.ndarray<33xi64, #dist.dist_env> -> !ndarray.ndarray<33xi64, #dist.dist_env> - %2 = ndarray.copy %1 : !ndarray.ndarray<33xi64, #dist.dist_env> -> !ndarray.ndarray<33xi64, #dist.dist_env, #region.gpu_env> - %4 = ndarray.copy %2 : !ndarray.ndarray<33xi64, #dist.dist_env, #region.gpu_env> -> !ndarray.ndarray<33xi64, #dist.dist_env> - %20, %21, %22 = "dist.parts_of"(%4) : (!ndarray.ndarray<33xi64, #dist.dist_env>) -> (!ndarray.ndarray<2xi64>, !ndarray.ndarray<6xi64>, !ndarray.ndarray<0xi64>) - return %20, %21, %22 : !ndarray.ndarray<2xi64>, !ndarray.ndarray<6xi64>, !ndarray.ndarray<0xi64> -} -// CHECK-LABEL: func.func @test_copy -// CHECK-SAME: [[arg0:%.*]]: !ndarray.ndarray<2xi64>, [[arg1:%.*]]: !ndarray.ndarray<6xi64>, [[arg2:%.*]]: !ndarray.ndarray<0xi64> -// CHECK: [[v1:%.*]] = ndarray.copy [[arg0]] : !ndarray.ndarray<2xi64> -> !ndarray.ndarray<2xi64> -// CHECK-NEXT: [[v2:%.*]] = ndarray.copy [[arg1]] : !ndarray.ndarray<6xi64> -> !ndarray.ndarray<6xi64> -// CHECK-NEXT: [[v3:%.*]] = ndarray.copy [[arg2]] : !ndarray.ndarray<0xi64> -> !ndarray.ndarray<0xi64> -// CHECK-NEXT: [[v4:%.*]] = ndarray.copy [[v1]] : !ndarray.ndarray<2xi64> -> !ndarray.ndarray<2xi64, #region.gpu_env> -// CHECK-NEXT: [[v5:%.*]] = ndarray.copy [[v2]] : !ndarray.ndarray<6xi64> -> !ndarray.ndarray<6xi64, #region.gpu_env> -// CHECK-NEXT: [[v6:%.*]] = ndarray.copy [[v3]] : !ndarray.ndarray<0xi64> -> !ndarray.ndarray<0xi64, #region.gpu_env> -// CHECK-NEXT: [[v7:%.*]] = ndarray.copy [[v4]] : !ndarray.ndarray<2xi64, #region.gpu_env> -> !ndarray.ndarray<2xi64> -// CHECK-NEXT: [[v8:%.*]] = ndarray.copy [[v5]] : !ndarray.ndarray<6xi64, #region.gpu_env> -> !ndarray.ndarray<6xi64> -// CHECK-NEXT: [[v9:%.*]] = ndarray.copy [[v6]] : !ndarray.ndarray<0xi64, #region.gpu_env> -> !ndarray.ndarray<0xi64> - -// ----- -func.func @test_delete(%arg0: !ndarray.ndarray<2xi64>, %arg1: !ndarray.ndarray<6xi64>, %arg2: !ndarray.ndarray<0xi64>) { - %c1 = arith.constant 1 : index - %a = dist.init_dist_array l_offset %c1 parts %arg0, %arg1, %arg2 : index, !ndarray.ndarray<2xi64>, !ndarray.ndarray<6xi64>, !ndarray.ndarray<0xi64> to !ndarray.ndarray<33xi64, #dist.dist_env> - ndarray.delete %a : !ndarray.ndarray<33xi64, #dist.dist_env> - return -} -// CHECK-LABEL: func.func @test_delete -// CHECK: ndarray.delete -// CHECK-SAME: : !ndarray.ndarray<2xi64> -// CHECK-NEXT: ndarray.delete -// CHECK-SAME: : !ndarray.ndarray<6xi64> -// CHECK-NEXT: ndarray.delete -// CHECK-SAME: : !ndarray.ndarray<0xi64> - -// ----- -func.func @test_init_dist_array(%arg0: !ndarray.ndarray<2xi64>, %arg1: !ndarray.ndarray<6xi64>, %arg2: !ndarray.ndarray<0xi64>) -> (!ndarray.ndarray<2xi64>, !ndarray.ndarray<6xi64>, !ndarray.ndarray<0xi64>) { - %c1 = arith.constant 1 : index - %a = dist.init_dist_array l_offset %c1 parts %arg0, %arg1, %arg2 : index, !ndarray.ndarray<2xi64>, !ndarray.ndarray<6xi64>, !ndarray.ndarray<0xi64> to !ndarray.ndarray<33xi64, #dist.dist_env> - %20, %21, %22 = "dist.parts_of"(%a) : (!ndarray.ndarray<33xi64, #dist.dist_env>) -> (!ndarray.ndarray<2xi64>, !ndarray.ndarray<6xi64>, !ndarray.ndarray<0xi64>) - return %20, %21, %22 : !ndarray.ndarray<2xi64>, !ndarray.ndarray<6xi64>, !ndarray.ndarray<0xi64> -} -// CHECK-LABEL: func.func @test_init_dist_array -// CHECK-SAME: [[arg0:%.*]]: !ndarray.ndarray<2xi64>, [[arg1:%.*]]: !ndarray.ndarray<6xi64>, [[arg2:%.*]]: !ndarray.ndarray<0xi64> -// CHECK: return [[arg0]], [[arg1]], [[arg2]] : !ndarray.ndarray<2xi64>, !ndarray.ndarray<6xi64>, !ndarray.ndarray<0xi64> - -// ----- -func.func @test_local_partition(%np : index, %prank: index, %shape: index) -> (index, index) { - %0, %1 = "dist.default_partition"(%np, %prank, %shape) {rank = 1 : i64} : (index, index, index) -> (index, index) - return %0, %1 : index, index -} -// CHECK-LABEL: func.func @test_local_partition(%arg0: index, %arg1: index, %arg2: index) -> (index, index) { -// CHECK: arith.remsi -// CHECK: arith.divsi -// CHECK-DAG: arith.addi -// CHECK-DAG: arith.cmpi -// CHECK-DAG: arith.select -// CHECK-DAG: arith.addi -// CHECK-DAG: arith.subi -// CHECK-DAG: arith.subi -// CHECK-DAG: arith.maxsi -// CHECK-DAG: arith.muli -// CHECK: arith.addi -// CHECK: arith.maxsi - - -// ----- -func.func @test_local_target_of_slice(%arg0: !ndarray.ndarray, %arg1: !ndarray.ndarray, %arg2: !ndarray.ndarray, %c0 : index, %c3 : index) -> (index, index) { - %c1 = arith.constant 1 : index - %a = dist.init_dist_array l_offset %c1 parts %arg0, %arg1, %arg2 : index, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray to !ndarray.ndarray> - %l_offsets, %l_sizes = dist.local_target_of_slice %a[%c0] [%c3] [%c3] : !ndarray.ndarray> to index, index - return %l_offsets, %l_sizes : index, index -} -// CHECK-LABEL: func.func @test_local_target_of_slice -// CHECK arith.constant -// CHECK memref.load -// CHECK arith.constant -// CHECK arith.constant -// CHECK ndarray.dim -// CHECK arith.constant -// CHECK ndarray.dim -// CHECK arith.addi -// CHECK arith.constant -// CHECK ndarray.dim -// CHECK arith.addi -// CHECK arith.constant -// CHECK arith.constant -// CHECK arith.muli -// CHECK arith.addi -// CHECK arith.addi -// CHECK arith.maxsi -// CHECK arith.subi -// CHECK arith.addi -// CHECK arith.subi -// CHECK arith.divsi -// CHECK arith.muli -// CHECK arith.addi -// CHECK arith.minsi -// CHECK arith.addi -// CHECK arith.subi -// CHECK arith.maxsi -// CHECK arith.divsi -// CHECK arith.subi -// CHECK arith.divsi -// CHECK arith.minsi -// CHECK arith.constant -// CHECK: return - -// ----- -func.func @test_copy_reshape() -> () { - %i1 = arith.constant 1 : i32 - %c2 = arith.constant 2 : index - %c3 = arith.constant 3 : index - %c4 = arith.constant 4 : index - - %0 = ndarray.create %c3, %c4 value %i1 {dtype = 4 : i8} : (index, index, i32) -> !ndarray.ndarray<3x4xi32, #dist.dist_env> - %1 = ndarray.reshape %0 %c2, %c3, %c2 {copy = true} - : !ndarray.ndarray<3x4xi32, #dist.dist_env> - -> !ndarray.ndarray<2x3x2xi32, #dist.dist_env> - return -} -// CHECK-LABEL: @test_copy_reshape -// CHECK: "distruntime.team_size"() <{team = 22 : i64}> : () -> index -// CHECK: "distruntime.team_member"() <{team = 22 : i64}> : () -> index -// CHECK: [[handle:%.*]], [[nlArray:%.*]] = distruntime.copy_reshape -// CHECK: "distruntime.wait"([[handle]]) : (!distruntime.asynchandle) -> () - -// ----- -func.func @test_repartition(%arg0: !ndarray.ndarray, %arg1: !ndarray.ndarray, %arg2: !ndarray.ndarray) -> (!ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray) { - %c0 = arith.constant 0 : index - %a = dist.init_dist_array l_offset %c0, %c0 parts %arg0, %arg1, %arg2 : index, index, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray to !ndarray.ndarray<10x12xi64, #dist.dist_env> - %4 = dist.repartition %a : !ndarray.ndarray<10x12xi64, #dist.dist_env> to !ndarray.ndarray<10x12xi64, #dist.dist_env> - %20, %21, %22 = "dist.parts_of"(%4) : (!ndarray.ndarray<10x12xi64, #dist.dist_env>) -> (!ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray) - return %20, %21, %22 : !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray -} -// CHECK-LABEL: @test_repartition -// CHECK: ndarray.dim -// CHECK: ndarray.dim -// CHECK: ndarray.dim -// CHECK: ndarray.dim -// CHECK: distruntime.get_halo - -// ----- -func.func @test_local_core(%arg0: !ndarray.ndarray, %arg1: !ndarray.ndarray, %arg2: !ndarray.ndarray, %arg3: index) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 0 : index - %c6 = arith.constant 6 : index - %c8 = arith.constant 8 : index - %a = dist.init_dist_array l_offset %arg3 parts %arg0, %arg1, %arg2 : index, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray to !ndarray.ndarray<16xi64, #dist.dist_env> - - %resultOffsets, %resultSizes = dist.local_core %a toffs %c0 tsizes %c6 soffs %c0 ssizes %c8 sstrides %c1 : !ndarray.ndarray<16xi64, #dist.dist_env> to index, index - %resultOffsets_6, %resultSizes_7 = dist.local_core %a toffs %c0 tsizes %c6 soffs %c1 ssizes %c8 sstrides %c1 coffs %resultOffsets csizes %resultSizes : !ndarray.ndarray<16xi64, #dist.dist_env> to index, index - return -} -// CHECK-LABEL: func.func @test_local_core -// CHECK-SAME: [[arg0:%.*]]: !ndarray.ndarray, [[arg1:%.*]]: !ndarray.ndarray, [[arg2:%.*]]: !ndarray.ndarray, [[arg3:%.*]]: index -// CHECK: [[vc0:%.*]] = arith.constant 0 : index -// CHECK: [[vc0_0:%.*]] = arith.constant 0 : index -// CHECK: [[vc6:%.*]] = arith.constant 6 : index -// CHECK: [[vc8:%.*]] = arith.constant 8 : index -// CHECK: [[vc0_1:%.*]] = arith.constant 0 : index -// CHECK: [[vdim:%.*]] = ndarray.dim [[arg1]] [[vc0_1]] : !ndarray.ndarray -> index -// CHECK: [[vc0_2:%.*]] = arith.constant 0 : index -// CHECK: [[vdim_3:%.*]] = ndarray.dim [[arg0]] [[vc0_2]] : !ndarray.ndarray -> index -// CHECK: [[v0:%.*]] = arith.addi [[arg3]], [[vdim_3]] : index -// CHECK: [[vc0_4:%.*]] = arith.constant 0 : index -// CHECK: [[vc1:%.*]] = arith.constant 1 : index -// CHECK: [[v1:%.*]] = arith.addi [[arg3]], [[vdim]] : index -// CHECK: [[v2:%.*]] = arith.maxsi [[arg3]], [[vc0]] : index -// CHECK: [[vcm1:%.*]] = arith.constant -1 : index -// CHECK: [[v3:%.*]] = arith.addi [[v2]], [[vcm1]] : index -// CHECK: [[v4:%.*]] = arith.divsi [[v3]], [[vc0_0]] : index -// CHECK: [[v5:%.*]] = arith.minsi [[v1]], [[vc0]] : index -// CHECK: [[v6:%.*]] = arith.addi [[v5]], [[vcm1]] : index -// CHECK: [[v7:%.*]] = arith.maxsi [[v6]], [[vc0_4]] : index -// CHECK: [[v8:%.*]] = arith.divsi [[v7]], [[vc0_0]] : index -// CHECK: [[vc0_5:%.*]] = arith.constant 0 : index -// CHECK: [[v9:%.*]] = arith.divsi [[vc0_5]], [[vc0_0]] : index -// CHECK: [[v10:%.*]] = arith.minsi [[v9]], [[vc8]] : index -// CHECK: [[vc0_6:%.*]] = arith.constant 0 : index -// CHECK: [[v11:%.*]] = arith.subi [[vc0]], [[v10]] : index -// CHECK: [[v12:%.*]] = arith.subi [[vc0_6]], [[v11]] : index -// CHECK: [[v13:%.*]] = arith.maxsi [[v12]], [[vc0_6]] : index -// CHECK: [[v14:%.*]] = arith.subi [[v8]], [[v11]] : index -// CHECK: [[v15:%.*]] = arith.subi [[v14]], [[v13]] : index -// CHECK: [[v16:%.*]] = arith.subi [[vc6]], [[v13]] : index -// CHECK: [[vc6_7:%.*]] = arith.constant 6 : index -// CHECK: [[v17:%.*]] = arith.subi [[vc6_7]], [[v13]] : index -// CHECK: [[v18:%.*]] = arith.minsi [[v15]], [[v16]] : index -// CHECK: [[v19:%.*]] = arith.minsi [[v17]], [[v18]] : index -// CHECK: [[vc0_8:%.*]] = arith.constant 0 : index -// CHECK: [[vdim_9:%.*]] = ndarray.dim [[arg1]] [[vc0_8]] : !ndarray.ndarray -> index -// CHECK: [[vc0_10:%.*]] = arith.constant 0 : index -// CHECK: [[vdim_11:%.*]] = ndarray.dim [[arg0]] [[vc0_10]] : !ndarray.ndarray -> index -// CHECK: [[v20:%.*]] = arith.addi [[arg3]], [[vdim_11]] : index -// CHECK: [[vc0_12:%.*]] = arith.constant 0 : index -// CHECK: [[vc1_13:%.*]] = arith.constant 1 : index -// CHECK: [[v21:%.*]] = arith.addi [[arg3]], [[vdim_9]] : index -// CHECK: [[v22:%.*]] = arith.maxsi [[arg3]], [[vc0_0]] : index -// CHECK: [[vcm1_14:%.*]] = arith.constant -1 : index -// CHECK: [[v23:%.*]] = arith.addi [[v22]], [[vcm1_14]] : index -// CHECK: [[v24:%.*]] = arith.divsi [[v23]], [[vc0_0]] : index -// CHECK: [[v25:%.*]] = arith.minsi [[v21]], [[vc0_0]] : index -// CHECK: [[v26:%.*]] = arith.addi [[v25]], [[vcm1_14]] : index -// CHECK: [[v27:%.*]] = arith.maxsi [[v26]], [[vc0_12]] : index -// CHECK: [[v28:%.*]] = arith.divsi [[v27]], [[vc0_0]] : index -// CHECK: [[vc0_15:%.*]] = arith.constant 0 : index -// CHECK: [[v29:%.*]] = arith.divsi [[vc0_15]], [[vc0_0]] : index -// CHECK: [[v30:%.*]] = arith.minsi [[v29]], [[vc8]] : index -// CHECK: [[vc0_16:%.*]] = arith.constant 0 : index -// CHECK: [[v31:%.*]] = arith.subi [[vc0]], [[v30]] : index -// CHECK: [[v32:%.*]] = arith.subi [[vc0_16]], [[v31]] : index -// CHECK: [[v33:%.*]] = arith.maxsi [[v13]], [[v32]] : index -// CHECK: [[v34:%.*]] = arith.subi [[v28]], [[v31]] : index -// CHECK: [[v35:%.*]] = arith.subi [[v34]], [[v33]] : index -// CHECK: [[v36:%.*]] = arith.subi [[vc6]], [[v33]] : index -// CHECK: [[v37:%.*]] = arith.addi [[v13]], [[v19]] : index -// CHECK: [[v38:%.*]] = arith.subi [[v37]], [[v33]] : index -// CHECK: [[v39:%.*]] = arith.minsi [[v35]], [[v36]] : index -// CHECK: [[v40:%.*]] = arith.minsi [[v38]], [[v39]] : index -// CHECK: return - -// ----- -func.func @test_cast_elemtype(%arg0: !ndarray.ndarray, %arg1: !ndarray.ndarray, %arg2: !ndarray.ndarray, %arg3: index) -> (!ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray) { - %a = dist.init_dist_array l_offset %arg3 parts %arg0, %arg1, %arg2 : index, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray to !ndarray.ndarray<16xi64, #dist.dist_env> - %4 = ndarray.cast_elemtype %a : !ndarray.ndarray<16xi64, #dist.dist_env> to !ndarray.ndarray<16xi32, #dist.dist_env> - %20, %21, %22 = "dist.parts_of"(%4) : (!ndarray.ndarray<16xi32, #dist.dist_env>) -> (!ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray) - return %20, %21, %22 : !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray -} -// CHECK-LABEL: @test_cast_elemtype -// CHECK: [[V1:%.*]] = ndarray.cast_elemtype %arg0 -// CHECK-NEXT: [[V2:%.*]] = ndarray.cast_elemtype %arg1 -// CHECK-NEXT: [[V3:%.*]] = ndarray.cast_elemtype %arg2 -// CHECK: return [[V1]], [[V2]], [[V3]] - -// ----- -func.func @test_copy_permute() -> () { - %i1 = arith.constant 1 : i32 - %c3 = arith.constant 3 : index - %c4 = arith.constant 4 : index - %src = ndarray.create %c3, %c4 value %i1 {dtype = 4 : i8} : (index, index, i32) -> !ndarray.ndarray<3x4xi32, #dist.dist_env> - %dist = ndarray.permute_dims %src [1, 0] - : !ndarray.ndarray<3x4xi32, #dist.dist_env> - -> !ndarray.ndarray<4x3xi32, #dist.dist_env> - return -} -// CHECK-LABEL: @test_copy_permute -// CHECK: "distruntime.team_size"() <{team = 22 : i64}> : () -> index -// CHECK: "distruntime.team_member"() <{team = 22 : i64}> : () -> index -// CHECK: [[handle:%.*]], [[nlArray:%.*]] = distruntime.copy_permute -// CHECK: "distruntime.wait"([[handle]]) : (!distruntime.asynchandle) -> () diff --git a/test/Conversion/DistToStandard/Subview.mlir b/test/Conversion/DistToStandard/Subview.mlir deleted file mode 100644 index 4307ead94..000000000 --- a/test/Conversion/DistToStandard/Subview.mlir +++ /dev/null @@ -1,125 +0,0 @@ -// RUN: imex-opt --convert-dist-to-standard -canonicalize %s -verify-diagnostics -o -| FileCheck %s - -func.func @test1(%arg0: !ndarray.ndarray<0xf32>, %arg1: !ndarray.ndarray<4xf32>, %arg2: !ndarray.ndarray<0xf32>) - -> (!ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray) { - %c2 = arith.constant 2 : index - %c4 = arith.constant 4 : index - %c6 = arith.constant 6 : index - %3 = dist.init_dist_array l_offset %c6 parts %arg0, %arg1, %arg2 : index, !ndarray.ndarray<0xf32>, !ndarray.ndarray<4xf32>, !ndarray.ndarray<0xf32> to !ndarray.ndarray<16xf32, #dist.dist_env> - - %4 = dist.subview %3[2] [8] [1] toffs %c4 tsizes %c4: !ndarray.ndarray<16xf32, #dist.dist_env> to !ndarray.ndarray<8xf32, #dist.dist_env> - %5 = dist.subview %3[2] [8] [1] toffs %c6 tsizes %c2: !ndarray.ndarray<16xf32, #dist.dist_env> to !ndarray.ndarray<8xf32, #dist.dist_env> - %6 = dist.subview %3[2] [8] [1] toffs %c2 tsizes %c4: !ndarray.ndarray<16xf32, #dist.dist_env> to !ndarray.ndarray<8xf32, #dist.dist_env> - - %20, %21, %22 = "dist.parts_of"(%4) : (!ndarray.ndarray<8xf32, #dist.dist_env>) -> (!ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray) - %30, %31, %32 = "dist.parts_of"(%5) : (!ndarray.ndarray<8xf32, #dist.dist_env>) -> (!ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray) - %40, %41, %42 = "dist.parts_of"(%6) : (!ndarray.ndarray<8xf32, #dist.dist_env>) -> (!ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray) - - return %20, %21, %22, %30, %31, %32, %40, %41, %42 : !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray,!ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray -} -// CHECK-LABEL: func.func @test1( -// CHECK-SAME: [[v0:%.*]]: !ndarray.ndarray<0xf32>, [[v1:%.*]]: !ndarray.ndarray<4xf32>, [[v2:%.*]]: !ndarray.ndarray<0xf32> -// CHECK: ndarray.subview [[v0]][0] [0] [1] : !ndarray.ndarray<0xf32> to !ndarray.ndarray<0xf32> -// CHECK: ndarray.subview [[v1]][0] [4] [1] : !ndarray.ndarray<4xf32> to !ndarray.ndarray<4xf32> -// CHECK: ndarray.subview [[v2]][0] [0] [1] : !ndarray.ndarray<0xf32> to !ndarray.ndarray<0xf32> -// CHECK: ndarray.subview [[v0]][0] [0] [1] : !ndarray.ndarray<0xf32> to !ndarray.ndarray<0xf32> -// CHECK: ndarray.subview [[v1]][2] [2] [1] : !ndarray.ndarray<4xf32> to !ndarray.ndarray<2xf32> -// CHECK: ndarray.subview [[v2]][0] [0] [1] : !ndarray.ndarray<0xf32> to !ndarray.ndarray<0xf32> -// CHECK: ndarray.subview [[v0]][0] [0] [1] : !ndarray.ndarray<0xf32> to !ndarray.ndarray<0xf32> -// CHECK: ndarray.subview [[v1]][0] [2] [1] : !ndarray.ndarray<4xf32> to !ndarray.ndarray<2xf32> -// CHECK: ndarray.subview [[v2]][0] [0] [1] : !ndarray.ndarray<0xf32> to !ndarray.ndarray<0xf32> - -// ----- -func.func private @printMemrefInd(memref<*xindex>) - func.func @test2(%arg0: !ndarray.ndarray<1xf32>, %arg1: !ndarray.ndarray<4xf32>, %arg2: !ndarray.ndarray<0xf32>) - -> (!ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray) { - %c2 = arith.constant 2 : index - %c4 = arith.constant 4 : index - %c5 = arith.constant 5 : index - %c6 = arith.constant 6 : index - %3 = dist.init_dist_array l_offset %c5 parts %arg0, %arg1, %arg2 : index, !ndarray.ndarray<1xf32>, !ndarray.ndarray<4xf32>, !ndarray.ndarray<0xf32> to !ndarray.ndarray<16xf32, #dist.dist_env> - - %4 = dist.subview %3[2] [8] [1] toffs %c4 tsizes %c4: !ndarray.ndarray<16xf32, #dist.dist_env> to !ndarray.ndarray<8xf32, #dist.dist_env> - %5 = dist.subview %3[2] [8] [1] toffs %c6 tsizes %c2: !ndarray.ndarray<16xf32, #dist.dist_env> to !ndarray.ndarray<8xf32, #dist.dist_env> - %6 = dist.subview %3[2] [8] [1] toffs %c2 tsizes %c4: !ndarray.ndarray<16xf32, #dist.dist_env> to !ndarray.ndarray<8xf32, #dist.dist_env> - - %20, %21, %22 = "dist.parts_of"(%4) : (!ndarray.ndarray<8xf32, #dist.dist_env>) -> (!ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray) - %30, %31, %32 = "dist.parts_of"(%5) : (!ndarray.ndarray<8xf32, #dist.dist_env>) -> (!ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray) - %40, %41, %42 = "dist.parts_of"(%6) : (!ndarray.ndarray<8xf32, #dist.dist_env>) -> (!ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray) - - return %20, %21, %22, %30, %31, %32, %40, %41, %42 : !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray,!ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray -} -// CHECK-LABEL: func.func @test2( -// CHECK-SAME: [[v0:%.*]]: !ndarray.ndarray<1xf32>, [[v1:%.*]]: !ndarray.ndarray<4xf32>, [[v2:%.*]]: !ndarray.ndarray<0xf32> -// CHECK: ndarray.subview [[v0]][1] [0] [1] : !ndarray.ndarray<1xf32> to !ndarray.ndarray<0xf32> -// CHECK: ndarray.subview [[v1]][0] [4] [1] : !ndarray.ndarray<4xf32> to !ndarray.ndarray<4xf32> -// CHECK: ndarray.subview [[v2]][0] [0] [1] : !ndarray.ndarray<0xf32> to !ndarray.ndarray<0xf32> -// CHECK: ndarray.subview [[v0]][1] [0] [1] : !ndarray.ndarray<1xf32> to !ndarray.ndarray<0xf32> -// CHECK: ndarray.subview [[v1]][2] [2] [1] : !ndarray.ndarray<4xf32> to !ndarray.ndarray<2xf32> -// CHECK: ndarray.subview [[v2]][0] [0] [1] : !ndarray.ndarray<0xf32> to !ndarray.ndarray<0xf32> -// CHECK: ndarray.subview [[v0]][0] [1] [1] : !ndarray.ndarray<1xf32> to !ndarray.ndarray<1xf32> -// CHECK: ndarray.subview [[v1]][0] [2] [1] : !ndarray.ndarray<4xf32> to !ndarray.ndarray<2xf32> -// CHECK: ndarray.subview [[v2]][0] [0] [1] : !ndarray.ndarray<0xf32> to !ndarray.ndarray<0xf32> - -// ----- -func.func @test3(%arg0: !ndarray.ndarray<1xf32>, %arg1: !ndarray.ndarray<4xf32>, %arg2: !ndarray.ndarray<1xf32>) - -> (!ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray) { - %c0 = arith.constant 0 : index - %c2 = arith.constant 2 : index - %c4 = arith.constant 4 : index - %c6 = arith.constant 6 : index - %3 = dist.init_dist_array l_offset %c2 parts %arg0, %arg1, %arg2 : index, !ndarray.ndarray<1xf32>, !ndarray.ndarray<4xf32>, !ndarray.ndarray<1xf32> to !ndarray.ndarray<16xf32, #dist.dist_env> - - %4 = dist.subview %3[2] [8] [1] toffs %c4 tsizes %c4: !ndarray.ndarray<16xf32, #dist.dist_env> to !ndarray.ndarray<8xf32, #dist.dist_env> - %5 = dist.subview %3[2] [8] [1] toffs %c6 tsizes %c2: !ndarray.ndarray<16xf32, #dist.dist_env> to !ndarray.ndarray<8xf32, #dist.dist_env> - %6 = dist.subview %3[2] [8] [1] toffs %c0 tsizes %c6: !ndarray.ndarray<16xf32, #dist.dist_env> to !ndarray.ndarray<8xf32, #dist.dist_env> - - %20, %21, %22 = "dist.parts_of"(%4) : (!ndarray.ndarray<8xf32, #dist.dist_env>) -> (!ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray) - %30, %31, %32 = "dist.parts_of"(%5) : (!ndarray.ndarray<8xf32, #dist.dist_env>) -> (!ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray) - %40, %41, %42 = "dist.parts_of"(%6) : (!ndarray.ndarray<8xf32, #dist.dist_env>) -> (!ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray) - - return %20, %21, %22, %30, %31, %32, %40, %41, %42 : !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray,!ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray -} -// CHECK-LABEL: func.func @test3( -// CHECK-SAME: [[v0:%.*]]: !ndarray.ndarray<1xf32>, [[v1:%.*]]: !ndarray.ndarray<4xf32>, [[v2:%.*]]: !ndarray.ndarray<1xf32> -// CHECK: ndarray.subview [[v0]][1] [0] [1] : !ndarray.ndarray<1xf32> to !ndarray.ndarray<0xf32> -// CHECK: ndarray.subview [[v1]][3] [1] [1] : !ndarray.ndarray<4xf32> to !ndarray.ndarray<1xf32> -// CHECK: ndarray.subview [[v2]][0] [1] [1] : !ndarray.ndarray<1xf32> to !ndarray.ndarray<1xf32> -// CHECK: ndarray.subview [[v0]][1] [0] [1] : !ndarray.ndarray<1xf32> to !ndarray.ndarray<0xf32> -// CHECK: ndarray.subview [[v1]][4] [0] [1] : !ndarray.ndarray<4xf32> to !ndarray.ndarray<0xf32> -// CHECK: ndarray.subview [[v2]][1] [0] [1] : !ndarray.ndarray<1xf32> to !ndarray.ndarray<0xf32> -// CHECK: ndarray.subview [[v0]][0] [1] [1] : !ndarray.ndarray<1xf32> to !ndarray.ndarray<1xf32> -// CHECK: ndarray.subview [[v1]][0] [4] [1] : !ndarray.ndarray<4xf32> to !ndarray.ndarray<4xf32> -// CHECK: ndarray.subview [[v2]][0] [1] [1] : !ndarray.ndarray<1xf32> to !ndarray.ndarray<1xf32> - -// ----- -func.func @test4(%arg0: !ndarray.ndarray<0xf32>, %arg1: !ndarray.ndarray<4xf32>, %arg2: !ndarray.ndarray<1xf32>) - -> (!ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray) { - %c2 = arith.constant 2 : index - %c3 = arith.constant 3 : index - %c4 = arith.constant 4 : index - %c5 = arith.constant 5 : index - %c6 = arith.constant 6 : index - %3 = dist.init_dist_array l_offset %c3 parts %arg0, %arg1, %arg2 : index, !ndarray.ndarray<0xf32>, !ndarray.ndarray<4xf32>, !ndarray.ndarray<1xf32> to !ndarray.ndarray<16xf32, #dist.dist_env> - - %4 = dist.subview %3[2] [8] [1] toffs %c4 tsizes %c4: !ndarray.ndarray<16xf32, #dist.dist_env> to !ndarray.ndarray<8xf32, #dist.dist_env> - %5 = dist.subview %3[2] [8] [1] toffs %c5 tsizes %c4: !ndarray.ndarray<16xf32, #dist.dist_env> to !ndarray.ndarray<8xf32, #dist.dist_env> - %6 = dist.subview %3[2] [8] [1] toffs %c2 tsizes %c4: !ndarray.ndarray<16xf32, #dist.dist_env> to !ndarray.ndarray<8xf32, #dist.dist_env> - - %20, %21, %22 = "dist.parts_of"(%4) : (!ndarray.ndarray<8xf32, #dist.dist_env>) -> (!ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray) - %30, %31, %32 = "dist.parts_of"(%5) : (!ndarray.ndarray<8xf32, #dist.dist_env>) -> (!ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray) - %40, %41, %42 = "dist.parts_of"(%6) : (!ndarray.ndarray<8xf32, #dist.dist_env>) -> (!ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray) - - return %20, %21, %22, %30, %31, %32, %40, %41, %42 : !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray,!ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray -} -// CHECK-LABEL: func.func @test4( -// CHECK-SAME: [[v0:%.*]]: !ndarray.ndarray<0xf32>, [[v1:%.*]]: !ndarray.ndarray<4xf32>, [[v2:%.*]]: !ndarray.ndarray<1xf32> -// CHECK: ndarray.subview [[v0]][0] [0] [1] : !ndarray.ndarray<0xf32> to !ndarray.ndarray<0xf32> -// CHECK: ndarray.subview [[v1]][3] [1] [1] : !ndarray.ndarray<4xf32> to !ndarray.ndarray<1xf32> -// CHECK: ndarray.subview [[v2]][0] [1] [1] : !ndarray.ndarray<1xf32> to !ndarray.ndarray<1xf32> -// CHECK: ndarray.subview [[v0]][0] [0] [1] : !ndarray.ndarray<0xf32> to !ndarray.ndarray<0xf32> -// CHECK: ndarray.subview [[v1]][4] [0] [1] : !ndarray.ndarray<4xf32> to !ndarray.ndarray<0xf32> -// CHECK: ndarray.subview [[v2]][0] [1] [1] : !ndarray.ndarray<1xf32> to !ndarray.ndarray<1xf32> -// CHECK: ndarray.subview [[v0]][0] [0] [1] : !ndarray.ndarray<0xf32> to !ndarray.ndarray<0xf32> -// CHECK: ndarray.subview [[v1]][1] [3] [1] : !ndarray.ndarray<4xf32> to !ndarray.ndarray<3xf32> -// CHECK: ndarray.subview [[v2]][0] [1] [1] : !ndarray.ndarray<1xf32> to !ndarray.ndarray<1xf32> diff --git a/test/Dialect/Dist/IR/DistOps.mlir b/test/Dialect/Dist/IR/DistOps.mlir deleted file mode 100644 index 349914a7d..000000000 --- a/test/Dialect/Dist/IR/DistOps.mlir +++ /dev/null @@ -1,50 +0,0 @@ -// RUN: imex-opt %s | FileCheck %s -// Verify the printed output can be parsed. -// RUN: imex-opt %s | imex-opt | FileCheck %s -// RUN: imex-opt -mlir-print-op-generic %s | imex-opt | FileCheck %s - -// ----- -func.func @test_init_dist_array(%pt: !ndarray.ndarray, %loffs: index) -> !ndarray.ndarray> { - %1 = dist.init_dist_array l_offset %loffs parts %pt, %pt, %pt : index, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray to !ndarray.ndarray> - return %1 : !ndarray.ndarray> -} -// CHECK-LABEL: func.func @test_init_dist_array(%arg0: !ndarray.ndarray, %arg1: index) -> !ndarray.ndarray> { -// CHECK-NEXT: dist.init_dist_array - -// ----- -func.func @test_extract_from_dist(%arg0: !ndarray.ndarray>) { - %20, %21, %22 = "dist.parts_of"(%arg0) : (!ndarray.ndarray>) -> (!ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray) - %3 = "dist.local_offsets_of"(%arg0) : (!ndarray.ndarray>) -> index - return -} -// CHECK-LABEL: @test_extract_from_dist(%arg0: !ndarray.ndarray>) { -// CHECK-NEXT: :3 = "dist.parts_of"(%arg0) : (!ndarray.ndarray>) -> (!ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray) -// CHECK-NEXT: "dist.local_offsets_of"(%arg0) : (!ndarray.ndarray>) -> index - -// ----- -func.func @test_default_partition(%np : index, %prank: index, %shape: index) -> (index, index) { - %0, %1 = "dist.default_partition"(%np, %prank, %shape) {rank = 1 : i64} : (index, index, index) -> (index, index) - return %0, %1 : index, index -} -// CHECK-LABEL: func.func @test_default_partition(%arg0: index, %arg1: index, %arg2: index) -> (index, index) { -// CHECK-NEXT: "dist.default_partition"(%arg0, %arg1, %arg2) {rank = 1 : i64} : (index, index, index) -> (index, index) - -// ----- -func.func @test_local_target_of_slice(%arg0: !ndarray.ndarray>) -> (index, index) { - %c0 = arith.constant 0 : index - %c3 = arith.constant 3 : index - %l_offsets, %l_sizes = dist.local_target_of_slice %arg0[%c0] [%c3] [%c3] : !ndarray.ndarray> to index, index - return %l_offsets, %l_sizes : index, index -} -// CHECK-LABEL: @test_local_target_of_slice -// CHECK: [[C1:%.*]], [[C2:%.*]] = dist.local_target_of_slice -// CHECK: return [[C1]], [[C2]] - -// ----- -func.func @test_repartition(%arg0: !ndarray.ndarray>) -> (!ndarray.ndarray>) { - %0 = dist.repartition %arg0 : !ndarray.ndarray> to !ndarray.ndarray> - return %0 : !ndarray.ndarray> -} -// CHECK-LABEL: @test_repartition -// CHECK: [[C1:%.*]] = dist.repartition -// CHECK: return [[C1]] diff --git a/test/Dialect/Dist/IR/lit.local.cfg b/test/Dialect/Dist/IR/lit.local.cfg deleted file mode 100644 index 2328eb821..000000000 --- a/test/Dialect/Dist/IR/lit.local.cfg +++ /dev/null @@ -1,7 +0,0 @@ -if sys.platform == "win32": - local_excludes = ['DistOps.mlir'] -else: - local_excludes = [] - -if(not config.imex_enable_excluded_tests): - config.excludes.update(local_excludes) diff --git a/test/Dialect/Dist/Transforms/DistCoalesce.mlir b/test/Dialect/Dist/Transforms/DistCoalesce.mlir deleted file mode 100644 index 9c0d0c34a..000000000 --- a/test/Dialect/Dist/Transforms/DistCoalesce.mlir +++ /dev/null @@ -1,70 +0,0 @@ -// RUN: imex-opt --split-input-file --dist-coalesce %s -verify-diagnostics -o -| FileCheck %s - -module { - func.func @test_coalesce1() -> (!ndarray.ndarray>) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c5 = arith.constant 5 : index - %c10 = arith.constant 10 : index - %c30 = arith.constant 30 : index - %0 = ndarray.linspace %c0 %c10 %c10 false : (index, index, index) -> !ndarray.ndarray - %1 = dist.init_dist_array l_offset %c5 parts %0, %0, %0 : index, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray to !ndarray.ndarray> - %2 = dist.init_dist_array l_offset %c5 parts %0, %0, %0 : index, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray to !ndarray.ndarray> - %3 = dist.repartition %1 : !ndarray.ndarray> to !ndarray.ndarray> - %4 = dist.repartition %2 : !ndarray.ndarray> to !ndarray.ndarray> - %5 = "dist.ewbin"(%3, %4) {op = 0 : i32} : (!ndarray.ndarray>, !ndarray.ndarray>) -> !ndarray.ndarray> - %6 = dist.repartition %5 : !ndarray.ndarray> to !ndarray.ndarray> - %7 = dist.repartition %1 : !ndarray.ndarray> to !ndarray.ndarray> - %8 = "dist.ewbin"(%6, %7) {op = 0 : i32} : (!ndarray.ndarray>, !ndarray.ndarray>) -> !ndarray.ndarray> - return %8 : !ndarray.ndarray> - } -} -// CHECK-LABEL: func.func @test_coalesce1() -// CHECK: dist.repartition -// CHECK-NEXT: dist.repartition -// CHECK-NEXT: dist.ewbin -// CHECK-NEXT: dist.repartition -// CHECK-NEXT: dist.ewbin -// CHECK-NEXT: return - -// ----- -module { - func.func @test_coalesce2() -> (!ndarray.ndarray>) { - %c0 = arith.constant 0 : index - %c2 = arith.constant 2 : index - %c1 = arith.constant 1 : index - %c5 = arith.constant 5 : index - %c10 = arith.constant 10 : index - %c30 = arith.constant 30 : index - %0 = ndarray.linspace %c0 %c10 %c10 false : (index, index, index) -> !ndarray.ndarray - %1 = dist.init_dist_array l_offset %c5 parts %0, %0, %0 : index, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray to !ndarray.ndarray> - %v1 = dist.subview %1[%c0] [%c5] [%c2] : !ndarray.ndarray> to !ndarray.ndarray> - %v2 = dist.subview %1[%c5] [%c5] [%c1] : !ndarray.ndarray> to !ndarray.ndarray> - %3 = dist.repartition %v1 : !ndarray.ndarray> to !ndarray.ndarray> - %4 = dist.repartition %v2 : !ndarray.ndarray> to !ndarray.ndarray> - %5 = "dist.ewbin"(%3, %4) {op = 0 : i32} : (!ndarray.ndarray>, !ndarray.ndarray>) -> !ndarray.ndarray> - %v3 = dist.subview %1[%c1] [%c5] [%c1] : !ndarray.ndarray> to !ndarray.ndarray> - %6 = dist.repartition %5 : !ndarray.ndarray> to !ndarray.ndarray> - %7 = dist.repartition %v3 : !ndarray.ndarray> to !ndarray.ndarray> - %8 = "dist.ewbin"(%6, %7) {op = 0 : i32} : (!ndarray.ndarray>, !ndarray.ndarray>) -> !ndarray.ndarray> - %t_offsets, %t_sizes = dist.local_target_of_slice %1[%c1] [%c5] [%c2] : !ndarray.ndarray> to index, index - %10 = dist.repartition %8 loffs %t_offsets lsizes %t_sizes : !ndarray.ndarray>, index, index to !ndarray.ndarray> - ndarray.insert_slice %10 into %1[%c1] [%c5] [%c2] : !ndarray.ndarray> into !ndarray.ndarray> - return %1 : !ndarray.ndarray> - } -} -// CHECK-LABEL: func.func @test_coalesce2() -// CHECK: dist.init_dist_array -// CHECK: distruntime.team_size -// CHECK: distruntime.team_member -// CHECK: dist.local_target_of_slice -// CHECK-NEXT: dist.local_bounding_box -// CHECK-NEXT: dist.local_bounding_box -// CHECK-NEXT: dist.local_bounding_box -// CHECK-NEXT: dist.repartition -// CHECK-NEXT: dist.subview -// CHECK-NEXT: dist.subview -// CHECK-NEXT: dist.subview -// CHECK: dist.ewbin -// CHECK: dist.ewbin -// CHECK: ndarray.insert_slice diff --git a/test/Dialect/Dist/Transforms/DistInferEWCores.mlir b/test/Dialect/Dist/Transforms/DistInferEWCores.mlir deleted file mode 100644 index bc801d5a2..000000000 --- a/test/Dialect/Dist/Transforms/DistInferEWCores.mlir +++ /dev/null @@ -1,84 +0,0 @@ -// RUN: imex-opt --split-input-file --dist-infer-elementwise-cores -canonicalize %s -verify-diagnostics -o -| FileCheck %s - -module { - func.func @test_infer() -> !ndarray.ndarray<16xi64, #dist.dist_env> attributes {llvm.emit_c_interface} { - %c0 = arith.constant 0 : index - %c10 = arith.constant 10 : index - %c1_i64 = arith.constant 1 : i64 - %c3 = arith.constant 3 : index - %c2 = arith.constant 2 : index - %c2_i64 = arith.constant 2 : i64 - %c1 = arith.constant 1 : index - %c16 = arith.constant 16 : index - %c0_i64 = arith.constant 0 : i64 - %0 = "distruntime.team_size"() {team = 94098061490592 : i64} : () -> index - %1 = "distruntime.team_member"() {team = 94098061490592 : i64} : () -> index - %l_offsets, %l_shape = "dist.default_partition"(%0, %1, %c16) : (index, index, index) -> (index, index) - %2 = ndarray.create %l_shape value %c0_i64 {dtype = 2 : i8} : (index, i64) -> !ndarray.ndarray - %3 = ndarray.create %c0 {dtype = 2 : i8} : (index) -> !ndarray.ndarray<0xi64> - %4 = ndarray.cast %3 : !ndarray.ndarray<0xi64> to !ndarray.ndarray - %5 = dist.init_dist_array l_offset %l_offsets parts %4, %2, %4 : index, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray to !ndarray.ndarray<16xi64, #dist.dist_env> - %t_offsets, %t_sizes = dist.local_target_of_slice %5[%c3] [%c10] [%c1] : !ndarray.ndarray<16xi64, #dist.dist_env> to index, index - %result_offsets, %result_sizes = dist.local_bounding_box false [%c0] [%c10] [%c1] [%t_offsets] [%t_sizes] : index, index - %result_offsets_0, %result_sizes_1 = dist.local_bounding_box false [%c1] [%c10] [%c1] [%t_offsets] [%t_sizes] bboffs %result_offsets bb_sizes %result_sizes : index, index - %result_offsets_2, %result_sizes_3 = dist.local_bounding_box false [%c2] [%c10] [%c1] [%t_offsets] [%t_sizes] bboffs %result_offsets_0 bb_sizes %result_sizes_1 : index, index - %result_offsets_4, %result_sizes_5 = dist.local_bounding_box false [%c3] [%c10] [%c1] [%t_offsets] [%t_sizes] bboffs %result_offsets_2 bb_sizes %result_sizes_3 : index, index - %6 = dist.repartition %5 loffs %result_offsets_4 lsizes %result_sizes_5 : !ndarray.ndarray<16xi64, #dist.dist_env>, index, index to !ndarray.ndarray<16xi64, #dist.dist_env> - %7 = dist.subview %6[%c0] [10] [%c1] toffs %t_offsets tsizes %t_sizes : !ndarray.ndarray<16xi64, #dist.dist_env> to !ndarray.ndarray<10xi64, #dist.dist_env> - %8 = dist.subview %6[%c1] [10] [%c1] toffs %t_offsets tsizes %t_sizes : !ndarray.ndarray<16xi64, #dist.dist_env> to !ndarray.ndarray<10xi64, #dist.dist_env> - %9 = dist.subview %6[%c2] [10] [%c1] toffs %t_offsets tsizes %t_sizes : !ndarray.ndarray<16xi64, #dist.dist_env> to !ndarray.ndarray<10xi64, #dist.dist_env> - %10 = dist.subview %6[%c3] [10] [%c1] toffs %t_offsets tsizes %t_sizes : !ndarray.ndarray<16xi64, #dist.dist_env> to !ndarray.ndarray<10xi64, #dist.dist_env> - %11 = ndarray.create value %c2_i64 {dtype = 2 : i8} : (i64) -> !ndarray.ndarray - %12 = dist.init_dist_array parts %11 : !ndarray.ndarray to !ndarray.ndarray> - %13 = "dist.ewbin"(%7, %12) {op = 21 : i32} : (!ndarray.ndarray<10xi64, #dist.dist_env>, !ndarray.ndarray>) -> !ndarray.ndarray<10xi64, #dist.dist_env> - %14 = "dist.ewbin"(%13, %8) {op = 0 : i32} : (!ndarray.ndarray<10xi64, #dist.dist_env>, !ndarray.ndarray<10xi64, #dist.dist_env>) -> !ndarray.ndarray<10xi64, #dist.dist_env> - %15 = "dist.ewbin"(%14, %9) {op = 0 : i32} : (!ndarray.ndarray<10xi64, #dist.dist_env>, !ndarray.ndarray<10xi64, #dist.dist_env>) -> !ndarray.ndarray<10xi64, #dist.dist_env> - %16 = ndarray.create value %c1_i64 {dtype = 2 : i8} : (i64) -> !ndarray.ndarray - %17 = dist.init_dist_array parts %16 : !ndarray.ndarray to !ndarray.ndarray> - %18 = "dist.ewbin"(%10, %17) {op = 21 : i32} : (!ndarray.ndarray<10xi64, #dist.dist_env>, !ndarray.ndarray>) -> !ndarray.ndarray<10xi64, #dist.dist_env> - %19 = "dist.ewbin"(%15, %18) {op = 0 : i32} : (!ndarray.ndarray<10xi64, #dist.dist_env>, !ndarray.ndarray<10xi64, #dist.dist_env>) -> !ndarray.ndarray<10xi64, #dist.dist_env> - ndarray.insert_slice %19 into %5[%c3] [%c10] [%c1] : !ndarray.ndarray<10xi64, #dist.dist_env> into !ndarray.ndarray<16xi64, #dist.dist_env> - %20 = "ndarray.cast"(%5) : (!ndarray.ndarray<16xi64, #dist.dist_env>) -> !ndarray.ndarray<16xi64, #dist.dist_env> - return %20 : !ndarray.ndarray<16xi64, #dist.dist_env> - } -} -// CHECK-LABEL: func.func @test_infer() -// CHECK: [[vc10:%.*]] = arith.constant 10 : index -// CHECK: [[vc0:%.*]] = arith.constant 0 : index -// CHECK: [[vc1_i64:%.*]] = arith.constant 1 : i64 -// CHECK: [[vc3:%.*]] = arith.constant 3 : index -// CHECK: [[vc2:%.*]] = arith.constant 2 : index -// CHECK: [[vc2_i64:%.*]] = arith.constant 2 : i64 -// CHECK: [[vc1:%.*]] = arith.constant 1 : index -// CHECK: [[vc16:%.*]] = arith.constant 16 : index -// CHECK: [[vc0_i64:%.*]] = arith.constant 0 : i64 -// CHECK: [[v0:%.*]] = "distruntime.team_size"() <{team = 94098061490592 : i64}> : () -> index -// CHECK: [[v1:%.*]] = "distruntime.team_member"() <{team = 94098061490592 : i64}> : () -> index -// CHECK: [[vl_offsets:%.*]], [[vl_shape:%.*]] = "dist.default_partition"([[v0]], [[v1]], [[vc16]]) : (index, index, index) -> (index, index) -// CHECK: [[v2:%.*]] = ndarray.create [[vl_shape]] value [[vc0_i64]] {dtype = 2 : i8} : (index, i64) -> !ndarray.ndarray -// CHECK: [[v3:%.*]] = ndarray.create [[vc0]] {dtype = 2 : i8} : (index) -> !ndarray.ndarray<0xi64> -// CHECK: [[v4:%.*]] = ndarray.cast [[v3]] : !ndarray.ndarray<0xi64> to !ndarray.ndarray -// CHECK: [[v5:%.*]] = dist.init_dist_array l_offset [[vl_offsets]] parts [[v4]], [[v2]], [[v4]] : index, !ndarray.ndarray, !ndarray.ndarray, !ndarray.ndarray to !ndarray.ndarray<16xi64, #dist.dist_env> -// CHECK: [[vt_offsets:%.*]], [[vt_sizes:%.*]] = dist.local_target_of_slice [[v5]][[[vc3]]] [[[vc10]]] [[[vc1]]] : !ndarray.ndarray<16xi64, #dist.dist_env> to index, index -// CHECK: [[vresultOffsets:%.*]], [[vresultSizes:%.*]] = dist.local_core [[v5]] toffs [[vt_offsets]] tsizes [[vt_sizes]] soffs %c{{[0-9]}} ssizes [[vc10]] sstrides [[vc1]] : !ndarray.ndarray<16xi64, #dist.dist_env> to index, index -// CHECK: [[vresultOffsets_0:%.*]], [[vresultSizes_1:%.*]] = dist.local_core [[v5]] toffs [[vt_offsets]] tsizes [[vt_sizes]] soffs %c{{[0-9]}} ssizes [[vc10]] sstrides [[vc1]] coffs [[vresultOffsets]] csizes [[vresultSizes]] : !ndarray.ndarray<16xi64, #dist.dist_env> to index, index -// CHECK: [[vresultOffsets_2:%.*]], [[vresultSizes_3:%.*]] = dist.local_core [[v5]] toffs [[vt_offsets]] tsizes [[vt_sizes]] soffs %c{{[0-9]}} ssizes [[vc10]] sstrides [[vc1]] coffs [[vresultOffsets_0]] csizes [[vresultSizes_1]] : !ndarray.ndarray<16xi64, #dist.dist_env> to index, index -// CHECK: [[vresultOffsets_4:%.*]], [[vresultSizes_5:%.*]] = dist.local_core [[v5]] toffs [[vt_offsets]] tsizes [[vt_sizes]] soffs %c{{[0-9]}} ssizes [[vc10]] sstrides [[vc1]] coffs [[vresultOffsets_2]] csizes [[vresultSizes_3]] : !ndarray.ndarray<16xi64, #dist.dist_env> to index, index -// CHECK: [[vresult_offsets:%.*]], [[vresult_sizes:%.*]] = dist.local_bounding_box false[[[vc0]]] [[[vc10]]] [[[vc1]]] [[[vt_offsets]]] [[[vt_sizes]]] : index, index -// CHECK: [[vresult_offsets_6:%.*]], [[vresult_sizes_7:%.*]] = dist.local_bounding_box false[[[vc1]]] [[[vc10]]] [[[vc1]]] [[[vt_offsets]]] [[[vt_sizes]]] bboffs [[vresult_offsets]] bb_sizes [[vresult_sizes]] : index, index -// CHECK: [[vresult_offsets_8:%.*]], [[vresult_sizes_9:%.*]] = dist.local_bounding_box false[[[vc2]]] [[[vc10]]] [[[vc1]]] [[[vt_offsets]]] [[[vt_sizes]]] bboffs [[vresult_offsets_6]] bb_sizes [[vresult_sizes_7]] : index, index -// CHECK: [[vresult_offsets_10:%.*]], [[vresult_sizes_11:%.*]] = dist.local_bounding_box false[[[vc3]]] [[[vc10]]] [[[vc1]]] [[[vt_offsets]]] [[[vt_sizes]]] bboffs [[vresult_offsets_8]] bb_sizes [[vresult_sizes_9]] : index, index -// CHECK: [[v6:%.*]] = dist.repartition [[v5]] loffs [[vresult_offsets_10]] lsizes [[vresult_sizes_11]] : !ndarray.ndarray<16xi64, #dist.dist_env>, index, index to !ndarray.ndarray<16xi64, #dist.dist_env> -// CHECK: [[v7:%.*]] = dist.subview [[v6]][[[vc0]]] [10] [[[vc1]]] toffs [[vt_offsets]] tsizes [[vt_sizes]] : !ndarray.ndarray<16xi64, #dist.dist_env> to !ndarray.ndarray<10xi64, #dist.dist_env> -// CHECK: [[v8:%.*]] = dist.subview [[v6]][[[vc1]]] [10] [[[vc1]]] toffs [[vt_offsets]] tsizes [[vt_sizes]] : !ndarray.ndarray<16xi64, #dist.dist_env> to !ndarray.ndarray<10xi64, #dist.dist_env> -// CHECK: [[v9:%.*]] = dist.subview [[v6]][[[vc2]]] [10] [[[vc1]]] toffs [[vt_offsets]] tsizes [[vt_sizes]] : !ndarray.ndarray<16xi64, #dist.dist_env> to !ndarray.ndarray<10xi64, #dist.dist_env> -// CHECK: [[v10:%.*]] = dist.subview [[v6]][[[vc3]]] [10] [[[vc1]]] toffs [[vt_offsets]] tsizes [[vt_sizes]] : !ndarray.ndarray<16xi64, #dist.dist_env> to !ndarray.ndarray<10xi64, #dist.dist_env> -// CHECK: [[v11:%.*]] = ndarray.create value [[vc2_i64]] {dtype = 2 : i8} : (i64) -> !ndarray.ndarray -// CHECK: [[v12:%.*]] = dist.init_dist_array parts [[v11]] : !ndarray.ndarray to !ndarray.ndarray> -// CHECK: [[v13:%.*]] = "dist.ewbin"([[v7]], [[v12]], [[vresultOffsets_4]], [[vresultSizes_5]], [[vt_offsets]]) <{op = 21 : i32}> : (!ndarray.ndarray<10xi64, #dist.dist_env>, !ndarray.ndarray>, index, index, index) -> !ndarray.ndarray<10xi64, #dist.dist_env> -// CHECK: [[v14:%.*]] = "dist.ewbin"([[v13]], [[v8]], [[vresultOffsets_4]], [[vresultSizes_5]], [[vt_offsets]]) <{op = 0 : i32}> : (!ndarray.ndarray<10xi64, #dist.dist_env>, !ndarray.ndarray<10xi64, #dist.dist_env>, index, index, index) -> !ndarray.ndarray<10xi64, #dist.dist_env> -// CHECK: [[v15:%.*]] = "dist.ewbin"([[v14]], [[v9]], [[vresultOffsets_4]], [[vresultSizes_5]], [[vt_offsets]]) <{op = 0 : i32}> : (!ndarray.ndarray<10xi64, #dist.dist_env>, !ndarray.ndarray<10xi64, #dist.dist_env>, index, index, index) -> !ndarray.ndarray<10xi64, #dist.dist_env> -// CHECK: [[v16:%.*]] = ndarray.create value [[vc1_i64]] {dtype = 2 : i8} : (i64) -> !ndarray.ndarray -// CHECK: [[v17:%.*]] = dist.init_dist_array parts [[v16]] : !ndarray.ndarray to !ndarray.ndarray> -// CHECK: [[v18:%.*]] = "dist.ewbin"([[v10]], [[v17]], [[vresultOffsets_4]], [[vresultSizes_5]], [[vt_offsets]]) <{op = 21 : i32}> : (!ndarray.ndarray<10xi64, #dist.dist_env>, !ndarray.ndarray>, index, index, index) -> !ndarray.ndarray<10xi64, #dist.dist_env> -// CHECK: [[v19:%.*]] = "dist.ewbin"([[v15]], [[v18]], [[vresultOffsets_4]], [[vresultSizes_5]], [[vt_offsets]]) <{op = 0 : i32}> : (!ndarray.ndarray<10xi64, #dist.dist_env>, !ndarray.ndarray<10xi64, #dist.dist_env>, index, index, index) -> !ndarray.ndarray<10xi64, #dist.dist_env> diff --git a/test/Dialect/NDArray/Transforms/NDArrayDist.mlir b/test/Dialect/NDArray/Transforms/NDArrayDist.mlir deleted file mode 100644 index 03ced4c92..000000000 --- a/test/Dialect/NDArray/Transforms/NDArrayDist.mlir +++ /dev/null @@ -1,60 +0,0 @@ -// RUN: imex-opt --split-input-file --ndarray-dist %s -verify-diagnostics -o -| FileCheck %s - -func.func @test_linspace(%arg0: i64, %arg1: i64, %arg2: i64) -> i64 { - %c0 = arith.constant 0 : index - %c3 = arith.constant 3 : index - %c33 = arith.constant 33 : i64 - %c22 = arith.constant 22 : index - %v = arith.constant 55 : i64 - %s = arith.index_cast %arg0 : i64 to index - %0 = ndarray.linspace %arg0 %arg1 %c33 false {team = 1} : (i64, i64, i64) -> !ndarray.ndarray<33xi64, #dist.dist_env> - %1 = ndarray.create %c22 value %v {team = 1, dtype = 2 : i8} : (index, i64) -> !ndarray.ndarray> - %10 = ndarray.subview %0[%c0][22][%c3] : !ndarray.ndarray<33xi64, #dist.dist_env> to !ndarray.ndarray> - %20 = ndarray.ewbin %10, %1 {op = 0 : i32} : (!ndarray.ndarray>, !ndarray.ndarray>) -> !ndarray.ndarray> - %21 = ndarray.reduction %20 {op = 4 : i32} : !ndarray.ndarray> -> !ndarray.ndarray> - %30 = builtin.unrealized_conversion_cast %21 : !ndarray.ndarray> to i64 - return %30 : i64 -} -// CHECK-LABEL: func.func @test_linspace -// CHECK: arith.constant -// CHECK: arith.constant -// CHECK: arith.constant -// CHECK: arith.constant -// CHECK: arith.constant -// CHECK: ndarray.linspace -// CHECK: ndarray.create -// CHECK: dist.subview -// CHECK: dist.repartition -// CHECK: dist.repartition -// CHECK: "dist.ewbin" -// CHECK: ndarray.reduction - -// ----- -func.func @test_dim(%arg0: !ndarray.ndarray<10x20xi64, #dist.dist_env>) -> index { - %c0 = arith.constant 0 : index - %1 = ndarray.dim %arg0 %c0 : !ndarray.ndarray<10x20xi64, #dist.dist_env> -> index - return %1 : index -} -// CHECK-LABEL: func.func @test_dim -// CHECK-NEXT: [[V:%.*]] = arith.constant 10 : index -// CHECK-NEXT: return [[V]] : index - -// ----- -func.func @test_ewuny(%arg0: !ndarray.ndarray<11xf64, #dist.dist_env>) -> !ndarray.ndarray<11xf64, #dist.dist_env> { - %0 ="ndarray.ewuny"(%arg0) {op = 0 : i32} : (!ndarray.ndarray<11xf64, #dist.dist_env>) -> !ndarray.ndarray> - %1 = builtin.unrealized_conversion_cast %0 : !ndarray.ndarray> to !ndarray.ndarray<11xf64, #dist.dist_env> - return %1 : !ndarray.ndarray<11xf64, #dist.dist_env> -} -// CHECK-LABEL: func.func @test_ewuny -// CHECK: "dist.ewuny" - -// ----- -func.func @test_ewbin(%arg0: !ndarray.ndarray<11xf64, #dist.dist_env>, %arg1: !ndarray.ndarray<11xf64, #dist.dist_env>) -> !ndarray.ndarray<11xf64, #dist.dist_env> { - %0 = ndarray.ewbin %arg0, %arg1 {op = 0 : i32} : (!ndarray.ndarray<11xf64, #dist.dist_env>, !ndarray.ndarray<11xf64, #dist.dist_env>) -> !ndarray.ndarray> - %1 = builtin.unrealized_conversion_cast %0 : !ndarray.ndarray> to !ndarray.ndarray<11xf64, #dist.dist_env> - return %1 : !ndarray.ndarray<11xf64, #dist.dist_env> -} -// CHECK-LABEL: func.func @test_ewbin -// CHECK: [[V1:%.*]] = dist.repartition -// CHECK: [[V2:%.*]] = dist.repartition -// CHECK: "dist.ewbin"([[V1]], [[V2]])