From d56abe33025a4089f4159062c4d0da92ac5b4f72 Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Fri, 3 Nov 2023 12:45:25 +0000 Subject: [PATCH] MapExpr supports concat op, WIP MapExpr2Ir --- paddle/cinn/adt/CMakeLists.txt | 10 ++-- paddle/cinn/adt/equation_function.cc | 23 ++++++-- paddle/cinn/adt/equation_function.h | 20 +++++-- paddle/cinn/adt/equation_solver.cc | 16 ++++-- paddle/cinn/adt/equation_value.h | 14 +++++ paddle/cinn/adt/equation_value_match_trait.h | 41 +++++++------- paddle/cinn/adt/m_ir.cc | 10 +++- paddle/cinn/adt/naive_op_equation_context.h | 36 +++++++++--- paddle/cinn/adt/op_equation_context.h | 13 ++++- paddle/cinn/adt/print_equations.cc | 13 ++++- paddle/cinn/adt/print_value.cc | 10 +++- paddle/cinn/hlir/op/transform.cc | 44 +++++++++++++++ test/cinn/adt/test_naive_concat.py | 59 ++++++++++++++++++++ 13 files changed, 252 insertions(+), 57 deletions(-) create mode 100644 test/cinn/adt/test_naive_concat.py diff --git a/paddle/cinn/adt/CMakeLists.txt b/paddle/cinn/adt/CMakeLists.txt index 32987eafd8f64..35ccd477151b9 100644 --- a/paddle/cinn/adt/CMakeLists.txt +++ b/paddle/cinn/adt/CMakeLists.txt @@ -4,6 +4,8 @@ gather_srcs( cinnapi_src SRCS anchor_sd_equation_context.cc + dim_expr_simplifier.cc + dim_expr.cc equation_function.cc equation_solver.cc equation_value.cc @@ -17,6 +19,7 @@ gather_srcs( naive_bidirection_equation_generator.cc naive_op_equation_context.cc partition_op_stmts.cc + print_dim_expr.cc print_equations.cc print_map_expr.cc print_schedule_descriptor.cc @@ -26,12 +29,9 @@ gather_srcs( schedule_descriptor.cc schedule_dim.cc schedule_mesh.cc - dim_expr.cc - dim_expr_simplifier.cc - symbolic_dim_infer_util.cc simplify_value.cc - write_broadcast_disabled_bidirection_equation_generator.cc - print_dim_expr.cc) + symbolic_dim_infer_util.cc + write_broadcast_disabled_bidirection_equation_generator.cc) cinn_cc_test(equation_value_match_trait_test SRCS equation_value_match_trait_test.cc DEPS gtest glog) diff --git a/paddle/cinn/adt/equation_function.cc b/paddle/cinn/adt/equation_function.cc index f69ae290e4c9c..2fdc9a8809ce2 100644 --- a/paddle/cinn/adt/equation_function.cc +++ b/paddle/cinn/adt/equation_function.cc @@ -33,9 +33,8 @@ CollectInputAndOutputVariables(const Function& function) { out_variables.emplace(Variable{out_index.value()}); in_variables.emplace(Variable{in_index.value()}); }, - [&](const IndexDot, - tOut, - tIn>>& dot) { + [&](const IndexDot, tOut, tIn>>& + dot) { const auto& [dims, out_index, in_iterators] = dot.tuple(); out_variables.emplace(Variable{out_index.value()}); for (const auto& iterator : *in_iterators.value()) { @@ -49,9 +48,17 @@ CollectInputAndOutputVariables(const Function& function) { out_variables.emplace(Variable{out_iterator.value()}); in_variables.emplace(Variable{in_iterator.value()}); }, - [&](const IndexUnDot, - tOut>, - tIn>& undot) { + [&](const SubFunction, tIn>& + sub_function) { + { + const auto& [dim, out_iterator, in_iterator] = + sub_function.tuple(); + out_variables.emplace(Variable{out_iterator.value()}); + in_variables.emplace(Variable{in_iterator.value()}); + } + }, + [&](const IndexUnDot, tOut>, tIn>& + undot) { const auto& [dims, out_iterators, in_index] = undot.tuple(); for (const auto& iterator : *out_iterators.value()) { out_variables.emplace(Variable{iterator}); @@ -113,6 +120,8 @@ std::string GetFunctionTypeName(const Function& function) { tIn>& broadcast) { return "GetBroadcastedIterator"; }, + [&](const SubFunction, tIn>& + sub_function) { return "SubFunction"; }, [&](const IndexUnDot, tOut>, tIn>& undot) { return "IndexUnDot"; }, @@ -142,6 +151,8 @@ const void* GetFunctionDataPtr(const Function& function) { tOut, tIn>& broadcast) -> const void* { return &broadcast.tuple(); }, + [&](const SubFunction, tIn>& sub) + -> const void* { return &sub.tuple(); }, [&](const IndexUnDot, tOut>, tIn>& undot) -> const void* { diff --git a/paddle/cinn/adt/equation_function.h b/paddle/cinn/adt/equation_function.h index c864ed0ede6f9..b12c41b09d997 100644 --- a/paddle/cinn/adt/equation_function.h +++ b/paddle/cinn/adt/equation_function.h @@ -106,19 +106,27 @@ struct GetBroadcastedIterator, tIn> using Tuple, tIn>::Tuple; }; +template +struct SubFunction; + +template <> +struct SubFunction, tIn> + : public Tuple, tIn> { + using Tuple, tIn>::Tuple; +}; + // clang-format off DEFINE_ADT_UNION(Equation, Identity, tIn>, Identity, tIn>, GetBroadcastedIterator, tIn>, - IndexDot, tOut, - tIn>>, - IndexUnDot, - tOut>, tIn>, + SubFunction, tIn>, + IndexDot, tOut, tIn>>, + IndexUnDot, tOut>, tIn>, InMsg2OutMsg, - tOut>>, - tIn>>, + tOut>>, + tIn>>, ConstantFunction, tIn>); // clang-format on diff --git a/paddle/cinn/adt/equation_solver.cc b/paddle/cinn/adt/equation_solver.cc index 35f3ae533281e..adad52e091445 100644 --- a/paddle/cinn/adt/equation_solver.cc +++ b/paddle/cinn/adt/equation_solver.cc @@ -85,8 +85,16 @@ std::unordered_map InferValuesImpl( } std::unordered_map InferValuesImpl( - const IndexUnDot, tOut>, tIn>& - undot, + const SubFunction, tIn>& sub, + IndexExprInferContext* ctx) { + const auto& [dim, out_iterator, in_iterator] = sub.tuple(); + SubValue sub_iterator{ctx->GetValue(in_iterator.value()), + dim}; + return {{out_iterator.value(), sub_iterator}}; +} + +std::unordered_map InferValuesImpl( + const IndexUnDot, tOut>, tIn>& undot, IndexExprInferContext* ctx) { const auto& [dims, out_iters, in_index] = undot.tuple(); @@ -94,8 +102,8 @@ std::unordered_map InferValuesImpl( for (const auto& dim : *dims) { dim_constants->emplace_back(dim); } - IndexUnDotValue> index_undot{ctx->GetValue(in_index.value()), - dim_constants}; + IndexUnDotValue> index_undot{ + ctx->GetValue(in_index.value()), dim_constants}; std::unordered_map ret{}; for (std::size_t idx = 0; idx < out_iters.value()->size(); ++idx) { diff --git a/paddle/cinn/adt/equation_value.h b/paddle/cinn/adt/equation_value.h index a876ffef1bf6d..4cca35644ccc9 100644 --- a/paddle/cinn/adt/equation_value.h +++ b/paddle/cinn/adt/equation_value.h @@ -68,6 +68,13 @@ struct BroadcastedIterator final : public Tuple { const ValueT& GetArg0() const { return std::get<0>(this->tuple()); } }; +template +struct SubValue final : public Tuple { + using Tuple::Tuple; + + const ValueT& GetArg0() const { return std::get<0>(this->tuple()); } +}; + DEFINE_ADT_UNION(Value, Undefined, Ok, @@ -76,6 +83,7 @@ DEFINE_ADT_UNION(Value, List, IndexDotValue>, IndexUnDotValue>, + SubValue, ListGetItem, BroadcastedIterator, PtrGetItem); @@ -89,6 +97,8 @@ using ListGetItem_Value_DimExpr = ListGetItem; OVERLOAD_OPERATOR_EQ_NE(ListGetItem_Value_DimExpr, TupleEqual); using BroadcastedIterator_Value_DimExpr = BroadcastedIterator; OVERLOAD_OPERATOR_EQ_NE(BroadcastedIterator_Value_DimExpr, TupleEqual); +using SubValue_Value_DimExpr = SubValue; +OVERLOAD_OPERATOR_EQ_NE(SubValue_Value_DimExpr, TupleEqual); OVERLOAD_OPERATOR_EQ_NE(PtrGetItem, TupleEqual); inline std::size_t GetHashValue(const Value& value); @@ -135,6 +145,10 @@ inline std::size_t GetHashValueImpl( const auto& [v, c] = value.tuple(); return hash_combine(GetHashValue(v), GetHashValue(c)); } +inline std::size_t GetHashValueImpl(const SubValue& value) { + const auto& [v, c] = value.tuple(); + return hash_combine(GetHashValue(v), GetHashValue(c)); +} inline std::size_t GetHashValueImpl(const PtrGetItem& value) { const auto& [pointer, c] = value.tuple(); return hash_combine(pointer.value().unique_id(), GetHashValue(c)); diff --git a/paddle/cinn/adt/equation_value_match_trait.h b/paddle/cinn/adt/equation_value_match_trait.h index 9b21c1146c8e3..496abc71badcb 100644 --- a/paddle/cinn/adt/equation_value_match_trait.h +++ b/paddle/cinn/adt/equation_value_match_trait.h @@ -47,37 +47,38 @@ struct MatchTrait> final { } }; -#define DEFINE_MATCH_TRAIT_VALUE_UNION_ARGSIZE_2(name, type0, type1) \ - template \ - struct MatchTrait> final { \ - using base_type = name; \ - \ - static constexpr int is_template = true; \ - \ - template