Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MapExpr supports concat op, WIP MapExpr2Ir #80

Open
wants to merge 1 commit into
base: adt-dynamic-shape
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions paddle/cinn/adt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ gather_srcs(
cinnapi_src
SRCS
anchor_sd_equation_context.cc
dim_expr_simplifier.cc
dim_expr.cc
equation_function.cc
equation_solver.cc
equation_value.cc
Expand All @@ -17,6 +19,7 @@ gather_srcs(
naive_bidirection_equation_generator.cc
naive_op_equation_context.cc
partition_op_stmts.cc
print_dim_expr.cc
print_equations.cc
print_map_expr.cc
print_schedule_descriptor.cc
Expand All @@ -26,12 +29,9 @@ gather_srcs(
schedule_descriptor.cc
schedule_dim.cc
schedule_mesh.cc
dim_expr.cc
dim_expr_simplifier.cc
symbolic_dim_infer_util.cc
simplify_value.cc
write_broadcast_disabled_bidirection_equation_generator.cc
print_dim_expr.cc)
symbolic_dim_infer_util.cc
write_broadcast_disabled_bidirection_equation_generator.cc)

cinn_cc_test(equation_value_match_trait_test SRCS
equation_value_match_trait_test.cc DEPS gtest glog)
Expand Down
23 changes: 17 additions & 6 deletions paddle/cinn/adt/equation_function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,8 @@ CollectInputAndOutputVariables(const Function& function) {
out_variables.emplace(Variable{out_index.value()});
in_variables.emplace(Variable{in_index.value()});
},
[&](const IndexDot<List<DimExpr>,
tOut<Index>,
tIn<List<Iterator>>>& dot) {
[&](const IndexDot<List<DimExpr>, tOut<Index>, tIn<List<Iterator>>>&
dot) {
const auto& [dims, out_index, in_iterators] = dot.tuple();
out_variables.emplace(Variable{out_index.value()});
for (const auto& iterator : *in_iterators.value()) {
Expand All @@ -49,9 +48,17 @@ CollectInputAndOutputVariables(const Function& function) {
out_variables.emplace(Variable{out_iterator.value()});
in_variables.emplace(Variable{in_iterator.value()});
},
[&](const IndexUnDot<List<DimExpr>,
tOut<List<Iterator>>,
tIn<Index>>& undot) {
[&](const SubFunction<DimExpr, tOut<Iterator>, tIn<Iterator>>&
sub_function) {
{
const auto& [dim, out_iterator, in_iterator] =
sub_function.tuple();
out_variables.emplace(Variable{out_iterator.value()});
in_variables.emplace(Variable{in_iterator.value()});
}
},
[&](const IndexUnDot<List<DimExpr>, tOut<List<Iterator>>, tIn<Index>>&
undot) {
const auto& [dims, out_iterators, in_index] = undot.tuple();
for (const auto& iterator : *out_iterators.value()) {
out_variables.emplace(Variable{iterator});
Expand Down Expand Up @@ -113,6 +120,8 @@ std::string GetFunctionTypeName(const Function& function) {
tIn<Iterator>>& broadcast) {
return "GetBroadcastedIterator";
},
[&](const SubFunction<DimExpr, tOut<Iterator>, tIn<Iterator>>&
sub_function) { return "SubFunction"; },
[&](const IndexUnDot<List<DimExpr>,
tOut<List<Iterator>>,
tIn<Index>>& undot) { return "IndexUnDot"; },
Expand Down Expand Up @@ -142,6 +151,8 @@ const void* GetFunctionDataPtr(const Function& function) {
tOut<Iterator>,
tIn<Iterator>>& broadcast)
-> const void* { return &broadcast.tuple(); },
[&](const SubFunction<DimExpr, tOut<Iterator>, tIn<Iterator>>& sub)
-> const void* { return &sub.tuple(); },
[&](const IndexUnDot<List<DimExpr>,
tOut<List<Iterator>>,
tIn<Index>>& undot) -> const void* {
Expand Down
20 changes: 14 additions & 6 deletions paddle/cinn/adt/equation_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,19 +106,27 @@ struct GetBroadcastedIterator<DimExpr, tOut<Iterator>, tIn<Iterator>>
using Tuple<DimExpr, tOut<Iterator>, tIn<Iterator>>::Tuple;
};

template <typename DimT, typename OutT, typename InT>
struct SubFunction;

template <>
struct SubFunction<DimExpr, tOut<Iterator>, tIn<Iterator>>
: public Tuple<DimExpr, tOut<Iterator>, tIn<Iterator>> {
using Tuple<DimExpr, tOut<Iterator>, tIn<Iterator>>::Tuple;
};

// clang-format off
DEFINE_ADT_UNION(Equation,
Identity<tOut<Iterator>, tIn<Iterator>>,
Identity<tOut<Index>, tIn<Index>>,
GetBroadcastedIterator<DimExpr,
tOut<Iterator>, tIn<Iterator>>,
IndexDot<List<DimExpr>, tOut<Index>,
tIn<List<Iterator>>>,
IndexUnDot<List<DimExpr>,
tOut<List<Iterator>>, tIn<Index>>,
SubFunction<DimExpr, tOut<Iterator>, tIn<Iterator>>,
IndexDot<List<DimExpr>, tOut<Index>, tIn<List<Iterator>>>,
IndexUnDot<List<DimExpr>, tOut<List<Iterator>>, tIn<Index>>,
InMsg2OutMsg<tOut<FakeOpPlaceHolder>,
tOut<OpArgIndexes<std::optional<Index>>>,
tIn<OpArgIndexes<Index>>>,
tOut<OpArgIndexes<std::optional<Index>>>,
tIn<OpArgIndexes<Index>>>,
ConstantFunction<tOut<Iterator>, tIn<Index>>);
// clang-format on

Expand Down
16 changes: 12 additions & 4 deletions paddle/cinn/adt/equation_solver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,17 +85,25 @@ std::unordered_map<Variable, Value> InferValuesImpl(
}

std::unordered_map<Variable, Value> InferValuesImpl(
const IndexUnDot<List<DimExpr>, tOut<List<Iterator>>, tIn<Index>>&
undot,
const SubFunction<DimExpr, tOut<Iterator>, tIn<Iterator>>& sub,
IndexExprInferContext* ctx) {
const auto& [dim, out_iterator, in_iterator] = sub.tuple();
SubValue<Value, DimExpr> sub_iterator{ctx->GetValue(in_iterator.value()),
dim};
return {{out_iterator.value(), sub_iterator}};
}

std::unordered_map<Variable, Value> InferValuesImpl(
const IndexUnDot<List<DimExpr>, tOut<List<Iterator>>, tIn<Index>>& undot,
IndexExprInferContext* ctx) {
const auto& [dims, out_iters, in_index] = undot.tuple();

List<DimExpr> dim_constants{};
for (const auto& dim : *dims) {
dim_constants->emplace_back(dim);
}
IndexUnDotValue<Value, List<DimExpr>> index_undot{ctx->GetValue(in_index.value()),
dim_constants};
IndexUnDotValue<Value, List<DimExpr>> index_undot{
ctx->GetValue(in_index.value()), dim_constants};

std::unordered_map<Variable, Value> ret{};
for (std::size_t idx = 0; idx < out_iters.value()->size(); ++idx) {
Expand Down
14 changes: 14 additions & 0 deletions paddle/cinn/adt/equation_value.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,13 @@ struct BroadcastedIterator final : public Tuple<ValueT, ConstantT> {
const ValueT& GetArg0() const { return std::get<0>(this->tuple()); }
};

template <typename ValueT, typename ConstantT>
struct SubValue final : public Tuple<ValueT, ConstantT> {
using Tuple<ValueT, ConstantT>::Tuple;

const ValueT& GetArg0() const { return std::get<0>(this->tuple()); }
};

DEFINE_ADT_UNION(Value,
Undefined,
Ok,
Expand All @@ -76,6 +83,7 @@ DEFINE_ADT_UNION(Value,
List<Value>,
IndexDotValue<Value, List<DimExpr>>,
IndexUnDotValue<Value, List<DimExpr>>,
SubValue<Value, DimExpr>,
ListGetItem<Value, DimExpr>,
BroadcastedIterator<Value, DimExpr>,
PtrGetItem<Value>);
Expand All @@ -89,6 +97,8 @@ using ListGetItem_Value_DimExpr = ListGetItem<Value, DimExpr>;
OVERLOAD_OPERATOR_EQ_NE(ListGetItem_Value_DimExpr, TupleEqual);
using BroadcastedIterator_Value_DimExpr = BroadcastedIterator<Value, DimExpr>;
OVERLOAD_OPERATOR_EQ_NE(BroadcastedIterator_Value_DimExpr, TupleEqual);
using SubValue_Value_DimExpr = SubValue<Value, DimExpr>;
OVERLOAD_OPERATOR_EQ_NE(SubValue_Value_DimExpr, TupleEqual);
OVERLOAD_OPERATOR_EQ_NE(PtrGetItem<Value>, TupleEqual);

inline std::size_t GetHashValue(const Value& value);
Expand Down Expand Up @@ -135,6 +145,10 @@ inline std::size_t GetHashValueImpl(
const auto& [v, c] = value.tuple();
return hash_combine(GetHashValue(v), GetHashValue(c));
}
inline std::size_t GetHashValueImpl(const SubValue<Value, DimExpr>& value) {
const auto& [v, c] = value.tuple();
return hash_combine(GetHashValue(v), GetHashValue(c));
}
inline std::size_t GetHashValueImpl(const PtrGetItem<Value>& value) {
const auto& [pointer, c] = value.tuple();
return hash_combine(pointer.value().unique_id(), GetHashValue(c));
Expand Down
41 changes: 21 additions & 20 deletions paddle/cinn/adt/equation_value_match_trait.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,37 +47,38 @@ struct MatchTrait<Value, List<T>> final {
}
};

#define DEFINE_MATCH_TRAIT_VALUE_UNION_ARGSIZE_2(name, type0, type1) \
template <typename T0, typename T1> \
struct MatchTrait<Value, name<T0, T1>> final { \
using base_type = name<type0, type1>; \
\
static constexpr int is_template = true; \
\
template <template <typename, typename> class Matcher> \
static bool MatchChildren(const base_type& value) { \
#define DEFINE_MATCH_TRAIT_VALUE_UNION_ARGSIZE_2(name, type0, type1) \
template <typename T0, typename T1> \
struct MatchTrait<Value, name<T0, T1>> final { \
using base_type = name<type0, type1>; \
\
static constexpr int is_template = true; \
\
template <template <typename, typename> class Matcher> \
static bool MatchChildren(const base_type& value) { \
return Matcher<T0, type0>::Call(std::get<0>(value.tuple())) && \
Matcher<T1, type1>::Call(std::get<1>(value.tuple())); \
} \
} \
};

DEFINE_MATCH_TRAIT_VALUE_UNION_ARGSIZE_2(ListGetItem, Value, DimExpr);
DEFINE_MATCH_TRAIT_VALUE_UNION_ARGSIZE_2(BroadcastedIterator, Value, DimExpr);
DEFINE_MATCH_TRAIT_VALUE_UNION_ARGSIZE_2(SubValue, Value, DimExpr);
DEFINE_MATCH_TRAIT_VALUE_UNION_ARGSIZE_2(IndexDotValue, Value, List<DimExpr>);
DEFINE_MATCH_TRAIT_VALUE_UNION_ARGSIZE_2(IndexUnDotValue, Value, List<DimExpr>);
#undef DEFINE_MATCH_TRAIT_VALUE_UNION_ARGSIZE_2

#define DEFINE_ADT_MATCH_TRAIT_EQUATION(name) \
template <typename T> \
struct MatchTrait<Value, name<T>> final { \
using base_type = name<Value>; \
\
static constexpr int is_template = true; \
\
template <template <typename, typename> class Matcher> \
static bool MatchChildren(const base_type& value) { \
#define DEFINE_ADT_MATCH_TRAIT_EQUATION(name) \
template <typename T> \
struct MatchTrait<Value, name<T>> final { \
using base_type = name<Value>; \
\
static constexpr int is_template = true; \
\
template <template <typename, typename> class Matcher> \
static bool MatchChildren(const base_type& value) { \
return Matcher<T, Value>::Call(std::get<0>(value.tuple())); \
} \
} \
};

DEFINE_ADT_MATCH_TRAIT_EQUATION(PtrGetItem);
Expand Down
10 changes: 8 additions & 2 deletions paddle/cinn/adt/m_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,15 @@ void CollectTensorIndexIteratorsImpl(
}

void CollectTensorIndexIteratorsImpl(
const BroadcastedIterator<Value, DimExpr>& broadcasted_iterator,
const BroadcastedIterator<Value, DimExpr>& tensor_index_expr,
std::unordered_set<Iterator>* ret) {
CollectTensorIndexIterators(broadcasted_iterator.GetArg0(), ret);
CollectTensorIndexIterators(tensor_index_expr.GetArg0(), ret);
}

void CollectTensorIndexIteratorsImpl(
const SubValue<Value, DimExpr>& tensor_index_expr,
std::unordered_set<Iterator>* ret) {
CollectTensorIndexIterators(tensor_index_expr.GetArg0(), ret);
}

void CollectTensorIndexIteratorsImpl(const PtrGetItem<Value>& tensor_index_expr,
Expand Down
36 changes: 29 additions & 7 deletions paddle/cinn/adt/naive_op_equation_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,22 @@ class NaiveOpEquationContext final : public OpEquationContext {
return input_tensor_iterator;
}

Iterator Sub(const Iterator& out_tensor_iterator,
const DimExpr& dim) override {
Iterator input_tensor_iterator{UniqueId::New()};
using Function = SubFunction<DimExpr, tOut<Iterator>, tIn<Iterator>>;
equations_->emplace_back(
Function{dim, input_tensor_iterator, out_tensor_iterator});
return input_tensor_iterator;
}

Iterator Identity(const Iterator& iterator) override {
Iterator new_iterator{UniqueId::New()};
using Function = cinn::adt::Identity<tOut<Iterator>, tIn<Iterator>>;
equations_->emplace_back(Function{new_iterator, iterator});
return new_iterator;
}

Iterator MakeConstantIterator(std::size_t constant,
Equations* equations) const {
using ConstF = ConstantFunction<tOut<Iterator>, tIn<Index>>;
Expand Down Expand Up @@ -227,8 +243,8 @@ class NaiveOpEquationContext final : public OpEquationContext {
}

std::optional<DimExpr> GetSymbolicDimSize(bool is_out,
std::size_t arg_idx,
std::size_t axis) const {
std::size_t arg_idx,
std::size_t axis) const {
const auto* Get = (is_out ? &GetSymbolicOutDim_ : &GetSymbolicInDim_);
const auto& opt_dim = (*Get)(arg_idx, axis);
return opt_dim;
Expand All @@ -247,7 +263,8 @@ class NaiveOpEquationContext final : public OpEquationContext {
}
}

void InitInputDimExpr(std::vector<DimTuple>* vec, const std::vector<std::uint64_t>& tensors_ranks) {
void InitInputDimExpr(std::vector<DimTuple>* vec,
const std::vector<std::uint64_t>& tensors_ranks) {
for (std::size_t i = 0; i < tensors_ranks.size(); ++i) {
vec->push_back(DimTuple{});
for (std::size_t j = 0; j < tensors_ranks.at(i); ++j) {
Expand All @@ -258,7 +275,8 @@ class NaiveOpEquationContext final : public OpEquationContext {
}
}

void InitOutputDimExpr(std::vector<DimTuple>* vec, const std::vector<std::uint64_t>& tensors_ranks) {
void InitOutputDimExpr(std::vector<DimTuple>* vec,
const std::vector<std::uint64_t>& tensors_ranks) {
for (std::size_t i = 0; i < tensors_ranks.size(); ++i) {
vec->push_back(DimTuple{});
for (std::size_t j = 0; j < tensors_ranks.at(i); ++j) {
Expand Down Expand Up @@ -303,8 +321,8 @@ class NaiveOpEquationContext final : public OpEquationContext {

template <typename T>
void Equal(const T& lhs, const T& rhs) {
equations_->emplace_back(Identity<tOut<T>, tIn<T>>(lhs, rhs));
equations_->emplace_back(Identity<tOut<T>, tIn<T>>(rhs, lhs));
equations_->emplace_back(cinn::adt::Identity<tOut<T>, tIn<T>>(lhs, rhs));
equations_->emplace_back(cinn::adt::Identity<tOut<T>, tIn<T>>(rhs, lhs));
}

static std::optional<std::size_t> FindPos(const List<Index>& vector,
Expand All @@ -317,13 +335,17 @@ class NaiveOpEquationContext final : public OpEquationContext {
return std::nullopt;
}

const utils::Attribute& GetAttribute(const std::string& name) const {
const utils::Attribute& GetAttribute(const std::string& name) const override {
const auto& iter = attr_map_type_->find(name);
CHECK(iter != attr_map_type_->end())
<< "Can't find Attribute with this name";
return iter->second;
}

bool HasAttr(const std::string& name) const override {
return attr_map_type_->find(name) != attr_map_type_->end();
}

std::vector<std::uint64_t> in_tensors_ranks_;
std::vector<std::uint64_t> out_tensors_ranks_;
GetArgStaticDimT GetInDim_;
Expand Down
13 changes: 10 additions & 3 deletions paddle/cinn/adt/op_equation_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@

#include "glog/logging.h"
#include "paddle/cinn/adt/adt.h"
#include "paddle/cinn/adt/equation.h"
#include "paddle/cinn/adt/dim_expr.h"
#include "paddle/cinn/adt/equation.h"
#include "paddle/cinn/hlir/framework/node.h"

namespace cinn::adt::config {
Expand All @@ -43,8 +43,13 @@ class OpEquationContext {

virtual void Equal(const IteratorTuple& lhs, const IteratorTuple& rhs) = 0;

virtual Iterator GetBroadcastedInputIterator(const Iterator& out_iterator,
const DimExpr& dim) = 0;
virtual Iterator GetBroadcastedInputIterator(
const Iterator& out_tensor_iterator, const DimExpr& dim) = 0;

virtual Iterator Sub(const Iterator& out_tensor_iterator,
const DimExpr& dim) = 0;

virtual Iterator Identity(const Iterator& iterator) = 0;

virtual const IteratorTuple& GetInIteratorTuple(
std::size_t input_idx) const = 0;
Expand All @@ -65,6 +70,8 @@ class OpEquationContext {
return absl::get<T>(GetAttribute(name));
}

virtual bool HasAttr(const std::string& name) const = 0;

protected:
OpEquationContext() = default;

Expand Down
Loading