From bf94356030830e3f27ba6fad6fe042c2b9b4e361 Mon Sep 17 00:00:00 2001 From: wangone <2279939962@qq.com> Date: Mon, 27 Sep 2021 12:24:52 +0800 Subject: [PATCH] add several primitive ops (#438) * meta op demo * add const_scalar and broadcast_to primitive ops --- cinn/frontend/syntax.cc | 134 +++++++++++++++- cinn/frontend/syntax.h | 87 +++++++++++ cinn/hlir/framework/graph.cc | 2 +- cinn/hlir/framework/graph_compiler.cc | 32 +++- cinn/hlir/op/CMakeLists.txt | 1 + cinn/hlir/op/broadcast.cc | 172 ++++++++++++--------- cinn/hlir/op/elementwise.cc | 197 ++++++++++++++++++++++-- cinn/hlir/op/nn.cc | 111 +++---------- cinn/hlir/op/op_util.cc | 1 + cinn/hlir/op/op_util.h | 18 +++ cinn/hlir/pass/CMakeLists.txt | 1 + cinn/hlir/pass/opfusion.cc | 7 +- cinn/hlir/pass/test_primitive_ops.cc | 92 +++++++++++ cinn/hlir/pe/broadcast.cc | 28 ++++ cinn/hlir/pe/broadcast.h | 13 +- cinn/hlir/pe/elementwise.cc | 3 +- cinn/hlir/pe/elementwise.h | 1 + tests/benchmark/test_all_ops_default.cc | 12 +- 18 files changed, 722 insertions(+), 190 deletions(-) mode change 100755 => 100644 cinn/frontend/syntax.cc create mode 100644 cinn/hlir/op/op_util.cc create mode 100644 cinn/hlir/op/op_util.h create mode 100644 cinn/hlir/pass/test_primitive_ops.cc diff --git a/cinn/frontend/syntax.cc b/cinn/frontend/syntax.cc old mode 100755 new mode 100644 index 784a1e5218..0865ae4f77 --- a/cinn/frontend/syntax.cc +++ b/cinn/frontend/syntax.cc @@ -107,6 +107,59 @@ Variable Program::batchnorm(const Variable& a, return instr.GetOutput(0); } +template +Variable Program::primitive_const_scalar(PrimType value, const std::string& name) { + Instruction instr("const_scalar"); + instr.SetInputs({}); + instr.SetAttr("value", value); + AppendInstruction(instr); + auto out = instr.GetOutput(0); + out.set_id(name); + auto out_type = type_of(); + CHECK(out_type.is_float() || out_type.is_int()) << "no supported type: " << out_type; + out->type = out_type; + return out; +} + +Variable Program::primitive_broadcast_to(const Variable& a, + const std::vector& out_shape, + const std::vector& broadcast_axes) { + Instruction instr("broadcast_to"); + instr.SetInputs({a}); + instr.SetAttr("out_shape", out_shape); + instr.SetAttr("broadcast_axes", broadcast_axes); + AppendInstruction(instr); + return instr.GetOutput(0); +} + +Variable Program::fused_batchnorm_inference(const Variable& a, + const Variable& scale, + const Variable& bias, + const Variable& mean, + const Variable& variance, + const std::unordered_map& attr_store) { + float epsilon = 0.00001f; + if (attr_store.find("epsilon") != attr_store.end()) { + epsilon = std::get(attr_store.at("epsilon")); + } + auto eps_var = primitive_const_scalar(epsilon, common::UniqName("epsilon")); + CHECK(!scale->shape.empty()) << "scale's shape is empty."; + auto broadcast_eps = primitive_broadcast_to(eps_var, scale->shape, {0}); + auto var_add_eps = add(variance, broadcast_eps); + auto rsrqt_var = primitive_rsqrt(var_add_eps); + auto new_scale = multiply(rsrqt_var, scale); + auto neg_mean = primitive_negative(mean); + auto new_shift = multiply(new_scale, neg_mean); + auto shift_bias = add(new_shift, bias); + CHECK(!a->shape.empty()) << "variable a's shape is empty."; + auto broadcast_new_scale = primitive_broadcast_to(new_scale, a->shape, {1}); + auto broadcast_shift_bias = primitive_broadcast_to(shift_bias, a->shape, {1}); + auto temp_out = multiply(broadcast_new_scale, a); + auto bn_out = add(temp_out, broadcast_shift_bias); + + return bn_out; +} + Variable Program::scale(const Variable& a, const std::unordered_map& attr_store) { Instruction instr("scale", {a}); for (auto& iter : attr_store) { @@ -198,6 +251,85 @@ Variable Program::add(const Variable& a, const Variable& b) { return instr.GetOutput(0); } +Variable Program::multiply(const Variable& a, const Variable& b) { + Instruction instr("elementwise_mul", {a, b}); + AppendInstruction(instr); + return instr.GetOutput(0); +} + +#define SYNTAX_PRIM_UNARY_IMPL(name__) \ + Variable Program::primitive_##name__(const Variable& a) { \ + Instruction instr(#name__, {a}); \ + AppendInstruction(instr); \ + return instr.GetOutput(0); \ + } + +SYNTAX_PRIM_UNARY_IMPL(exp); +SYNTAX_PRIM_UNARY_IMPL(erf); +SYNTAX_PRIM_UNARY_IMPL(sqrt); +SYNTAX_PRIM_UNARY_IMPL(log); +SYNTAX_PRIM_UNARY_IMPL(floor); +SYNTAX_PRIM_UNARY_IMPL(ceil); +SYNTAX_PRIM_UNARY_IMPL(round); +SYNTAX_PRIM_UNARY_IMPL(tanh); +SYNTAX_PRIM_UNARY_IMPL(log2); +SYNTAX_PRIM_UNARY_IMPL(log10); +SYNTAX_PRIM_UNARY_IMPL(trunc); +SYNTAX_PRIM_UNARY_IMPL(cos); +SYNTAX_PRIM_UNARY_IMPL(sin); +SYNTAX_PRIM_UNARY_IMPL(cosh); +SYNTAX_PRIM_UNARY_IMPL(tan); +SYNTAX_PRIM_UNARY_IMPL(sinh); +SYNTAX_PRIM_UNARY_IMPL(acos); +SYNTAX_PRIM_UNARY_IMPL(acosh); +SYNTAX_PRIM_UNARY_IMPL(asin); +SYNTAX_PRIM_UNARY_IMPL(asinh); +SYNTAX_PRIM_UNARY_IMPL(atan); +SYNTAX_PRIM_UNARY_IMPL(atanh); + +SYNTAX_PRIM_UNARY_IMPL(isnan); +SYNTAX_PRIM_UNARY_IMPL(isfinite); +SYNTAX_PRIM_UNARY_IMPL(isinf); +SYNTAX_PRIM_UNARY_IMPL(bitwise_not); + +SYNTAX_PRIM_UNARY_IMPL(negative); +SYNTAX_PRIM_UNARY_IMPL(identity); +SYNTAX_PRIM_UNARY_IMPL(logica_not); +SYNTAX_PRIM_UNARY_IMPL(sign); +SYNTAX_PRIM_UNARY_IMPL(abs); +SYNTAX_PRIM_UNARY_IMPL(rsqrt); + +#define SYNTAX_PRIM_BINARY_IMPL(name__) \ + Variable Program::primitive_##name__(const Variable& a, const Variable& b) { \ + Instruction instr(#name__, {a, b}); \ + AppendInstruction(instr); \ + return instr.GetOutput(0); \ + } + +SYNTAX_PRIM_BINARY_IMPL(substract) +SYNTAX_PRIM_BINARY_IMPL(divide) +SYNTAX_PRIM_BINARY_IMPL(floor_divide) +SYNTAX_PRIM_BINARY_IMPL(mod) +SYNTAX_PRIM_BINARY_IMPL(floor_mod) +SYNTAX_PRIM_BINARY_IMPL(max) +SYNTAX_PRIM_BINARY_IMPL(min) +SYNTAX_PRIM_BINARY_IMPL(power) +SYNTAX_PRIM_BINARY_IMPL(logical_and) +SYNTAX_PRIM_BINARY_IMPL(logical_or) +SYNTAX_PRIM_BINARY_IMPL(logical_xor) +SYNTAX_PRIM_BINARY_IMPL(greater) +SYNTAX_PRIM_BINARY_IMPL(less) +SYNTAX_PRIM_BINARY_IMPL(equal) +SYNTAX_PRIM_BINARY_IMPL(not_equal) +SYNTAX_PRIM_BINARY_IMPL(greater_equal) +SYNTAX_PRIM_BINARY_IMPL(less_equal) + +SYNTAX_PRIM_BINARY_IMPL(bitwise_or) +SYNTAX_PRIM_BINARY_IMPL(bitwise_xor) +SYNTAX_PRIM_BINARY_IMPL(bitwise_and) +SYNTAX_PRIM_BINARY_IMPL(left_shift) +SYNTAX_PRIM_BINARY_IMPL(right_shift) + Variable Program::elementwise_add(const Variable& a, const Variable& b, int axis) { Instruction instr("elementwise_add", {a, b}); instr.SetAttr("axis", axis); @@ -267,7 +399,7 @@ std::string _Instruction_::debug_string() const { ss << op_type; ss << "("; ss << utils::Join(input_names, ", "); - if (!attrs.empty()) ss << ", "; + if (!attrs.empty() && !input_names.empty()) ss << ", "; std::vector attr_strs; for (auto& attr : attrs) { diff --git a/cinn/frontend/syntax.h b/cinn/frontend/syntax.h index 88397ddb72..1d0d37e304 100644 --- a/cinn/frontend/syntax.h +++ b/cinn/frontend/syntax.h @@ -170,6 +170,12 @@ struct Program { : instrs_(std::move(instrs)), inputs_(std::move(inputs)) {} void SetInputs(const std::vector& xs); + + /** + * create scalar with the specific value and type + */ + template + Variable primitive_const_scalar(PrimType value, const std::string& name); /** * Add two variables. * @@ -178,6 +184,7 @@ struct Program { * @return The result. */ Variable add(const Variable& a, const Variable& b); + Variable multiply(const Variable& a, const Variable& b); /** * Multiply two matrix. @@ -190,6 +197,76 @@ struct Program { Variable mulbias( const Variable& a, const Variable& b, const Variable& c, int x_num_col_dims = 1, int y_num_col_dims = 1); +#define SYNTAX_PRIM_UNARY_DECL(name__) Variable primitive_##name__(const Variable& a); + + SYNTAX_PRIM_UNARY_DECL(exp); + SYNTAX_PRIM_UNARY_DECL(erf); + SYNTAX_PRIM_UNARY_DECL(sqrt); + SYNTAX_PRIM_UNARY_DECL(log); + SYNTAX_PRIM_UNARY_DECL(floor); + SYNTAX_PRIM_UNARY_DECL(ceil); + SYNTAX_PRIM_UNARY_DECL(round); + SYNTAX_PRIM_UNARY_DECL(tanh); + SYNTAX_PRIM_UNARY_DECL(log2); + SYNTAX_PRIM_UNARY_DECL(log10); + SYNTAX_PRIM_UNARY_DECL(trunc); + SYNTAX_PRIM_UNARY_DECL(cos); + SYNTAX_PRIM_UNARY_DECL(sin); + SYNTAX_PRIM_UNARY_DECL(cosh); + SYNTAX_PRIM_UNARY_DECL(tan); + SYNTAX_PRIM_UNARY_DECL(sinh); + SYNTAX_PRIM_UNARY_DECL(acos); + SYNTAX_PRIM_UNARY_DECL(acosh); + SYNTAX_PRIM_UNARY_DECL(asin); + SYNTAX_PRIM_UNARY_DECL(asinh); + SYNTAX_PRIM_UNARY_DECL(atan); + SYNTAX_PRIM_UNARY_DECL(atanh); + + SYNTAX_PRIM_UNARY_DECL(isnan); + SYNTAX_PRIM_UNARY_DECL(isfinite); + SYNTAX_PRIM_UNARY_DECL(isinf); + SYNTAX_PRIM_UNARY_DECL(bitwise_not); + + SYNTAX_PRIM_UNARY_DECL(negative); + SYNTAX_PRIM_UNARY_DECL(identity); + SYNTAX_PRIM_UNARY_DECL(logica_not); + SYNTAX_PRIM_UNARY_DECL(sign); + SYNTAX_PRIM_UNARY_DECL(abs); + SYNTAX_PRIM_UNARY_DECL(rsqrt); + +#define SYNTAX_PRIM_BINARY_DECL(name__) Variable primitive_##name__(const Variable& a, const Variable& b); + SYNTAX_PRIM_BINARY_DECL(substract) + SYNTAX_PRIM_BINARY_DECL(divide) + SYNTAX_PRIM_BINARY_DECL(floor_divide) + SYNTAX_PRIM_BINARY_DECL(mod) + SYNTAX_PRIM_BINARY_DECL(floor_mod) + SYNTAX_PRIM_BINARY_DECL(max) + SYNTAX_PRIM_BINARY_DECL(min) + SYNTAX_PRIM_BINARY_DECL(power) + SYNTAX_PRIM_BINARY_DECL(logical_and) + SYNTAX_PRIM_BINARY_DECL(logical_or) + SYNTAX_PRIM_BINARY_DECL(logical_xor) + SYNTAX_PRIM_BINARY_DECL(greater) + SYNTAX_PRIM_BINARY_DECL(less) + SYNTAX_PRIM_BINARY_DECL(equal) + SYNTAX_PRIM_BINARY_DECL(not_equal) + SYNTAX_PRIM_BINARY_DECL(greater_equal) + SYNTAX_PRIM_BINARY_DECL(less_equal) + + SYNTAX_PRIM_BINARY_DECL(bitwise_or) + SYNTAX_PRIM_BINARY_DECL(bitwise_xor) + SYNTAX_PRIM_BINARY_DECL(bitwise_and) + SYNTAX_PRIM_BINARY_DECL(left_shift) + SYNTAX_PRIM_BINARY_DECL(right_shift) + + // broadcast one operand to the target shape + // broadcast axes: the target axis which a's ith axis is mapped to + // Notes: a's dim should be one or same with the output dim mapped to. + // e.g. if a[64] broadcasts to out[1, 64, 112, 112], then out_shape is {1, 64, 112, 112} and broadcast_axes are {1} + Variable primitive_broadcast_to(const Variable& a, + const std::vector& out_shape, + const std::vector& broadcast_axes); + /** * Add two tensors element-wise. */ @@ -245,6 +322,16 @@ struct Program { const Variable& variance, const std::unordered_map& attr_store); + /** + * batchnorm composed of primitive ops + */ + Variable fused_batchnorm_inference(const Variable& a, + const Variable& scale, + const Variable& bias, + const Variable& mean, + const Variable& variance, + const std::unordered_map& attr_store); + Variable scale(const Variable& a, const std::unordered_map& attr_store); Variable softmax(const Variable& a, const std::unordered_map& attr_store); diff --git a/cinn/hlir/framework/graph.cc b/cinn/hlir/framework/graph.cc index 6688a2deb5..76ea01a476 100644 --- a/cinn/hlir/framework/graph.cc +++ b/cinn/hlir/framework/graph.cc @@ -27,8 +27,8 @@ Graph::Graph(const frontend::Program& prog, const Target& target) { graph_node->as()->LinkTo(node_tmp); } } + int out_idx = 0; for (auto& output_v : temp->outputs) { - int out_idx = 0; auto* output_data = new NodeData(node_ptr, out_idx++, 0, output_v->id); node_tmp->LinkTo(output_data); this->RegisterNode(output_v->id, output_data); diff --git a/cinn/hlir/framework/graph_compiler.cc b/cinn/hlir/framework/graph_compiler.cc index 6869b52f89..548fc3cd42 100644 --- a/cinn/hlir/framework/graph_compiler.cc +++ b/cinn/hlir/framework/graph_compiler.cc @@ -118,6 +118,19 @@ std::vector GraphCompiler::GetOpFunc(const Node* node) { return func; } +// get the most complex op's index in the fused groups according to the OpPattern. If the OpPattern is same, we will take the latter. +int GetMasterRefNode(const std::vector& nodes) { + auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); + int master_index = 0; + int master_pattern = op_pattern_dict[nodes[0]->op()]; + for (int i = 1; i < nodes.size(); i++) { + int pattern = op_pattern_dict[nodes[i]->op()]; + master_index = pattern >= master_pattern ? i : master_index; + } + VLOG(3) << "master_index: " << master_index << ", master op: " << nodes[master_index]->op()->name; + return master_index; +} + std::vector GraphCompiler::GetOpFunc(const std::vector& nodes) { CHECK_GT(nodes.size(), 1) << "fuse nodes number must be greater than 1"; auto& strategy = Operator::GetAttrs("CINNStrategy"); @@ -133,7 +146,8 @@ std::vector GraphCompiler::GetOpFunc(const std::vector& std::unordered_set in_vars; std::unordered_set out_vars; std::unordered_map temp_var_map; - ir::Tensor first_out_tensor; + ir::Tensor master_out_tensor; + int master_index = GetMasterRefNode(nodes); for (auto& node : nodes) { std::vector temp_inputs; std::vector cinn_inputs; @@ -181,12 +195,12 @@ std::vector GraphCompiler::GetOpFunc(const std::vector& OpStrategy::SelectImpl(strategy[node->op()](node->attrs, temp_inputs, out_types, output_shapes, target_)); common::CINNValuePack C = impl->fcompute(common::CINNValuePack{cinn_inputs}); - if (index == 0) { - // use the first op's schedule as the fused ops' schedule as complex op like conv appear in the first. + if (index == master_index) { + // use the most complex op's schedule as the fused ops' schedule. C = impl->fschedule(C); CHECK(!C.empty()); - Expr out = C[0]; - first_out_tensor = out.as_tensor_ref(); + Expr out = C[0]; + master_out_tensor = out.as_tensor_ref(); } CHECK_GE(C.size(), 2); CHECK_LE(C.size() - 1, node->outlinks_in_order().size()); @@ -237,8 +251,10 @@ std::vector GraphCompiler::GetOpFunc(const std::vector& inputs.insert(inputs.end(), outputs.begin(), outputs.end()); ir::Tensor final_out_tensor = outputs.front(); - stages[final_out_tensor]->CopyTransform(stages[first_out_tensor]); - stages[final_out_tensor]->CopyLoopInfo(stages[first_out_tensor]); + if (final_out_tensor->name != master_out_tensor->name) { + stages[final_out_tensor]->CopyTransform(stages[master_out_tensor]); + stages[final_out_tensor]->CopyLoopInfo(stages[master_out_tensor]); + } for (auto& s : stages) { auto& compute_ats = s.second->GetComputeAts(); @@ -255,7 +271,9 @@ std::vector GraphCompiler::GetOpFunc(const std::vector& new_relation.level = old_relation.level; compute_ats.clear(); + CHECK(new_relation.IsCompatible(s.second.get())) << "new computeAt should be compatible"; compute_ats[new_stage->id()] = new_relation; + break; } } } diff --git a/cinn/hlir/op/CMakeLists.txt b/cinn/hlir/op/CMakeLists.txt index 0dfc6207b0..2770a3530d 100644 --- a/cinn/hlir/op/CMakeLists.txt +++ b/cinn/hlir/op/CMakeLists.txt @@ -5,6 +5,7 @@ core_gather_srcs(SRCS broadcast.cc transform.cc elementwise.cc + op_util.cc ) cc_test(test_cinn_op_broadcast SRCS op_broadcast_test.cc DEPS cinncore) diff --git a/cinn/hlir/op/broadcast.cc b/cinn/hlir/op/broadcast.cc index 84c41b117b..8997abc187 100644 --- a/cinn/hlir/op/broadcast.cc +++ b/cinn/hlir/op/broadcast.cc @@ -5,6 +5,7 @@ #include "cinn/hlir/framework/node.h" #include "cinn/hlir/framework/op.h" #include "cinn/hlir/framework/op_strategy.h" +#include "cinn/hlir/op/op_util.h" #include "cinn/hlir/pe/nn.h" #include "cinn/hlir/pe/schedule.h" #include "cinn/ir/ir_operators.h" @@ -65,16 +66,13 @@ std::shared_ptr StrategyForBroadcast( CHECK(!args.empty()) << "The input argument of " << op_name << " schedule is empty! Please check."; CINNValuePack arg_pack = args[0]; CHECK_EQ(arg_pack.size(), 2UL); + Expr Out = arg_pack[0]; + poly::StageMap stages = arg_pack[1]; + CHECK(Out.as_tensor()); if (target.arch == Target::Arch::NVGPU) { - Expr Out = arg_pack[0]; - poly::StageMap stages = arg_pack[1]; - CHECK(Out.as_tensor()); - pe::CudaScheduleInjective(stages[Out.as_tensor_ref()], output_shapes.back(), target); + pe::CudaScheduleInjective(stages[Out.as_tensor_ref()], output_shapes.front(), target); } else if (target.arch == Target::Arch::X86) { - Expr Out = arg_pack[0]; - poly::StageMap stages = arg_pack[1]; - CHECK(Out.as_tensor()); - pe::ScheduleInjectiveCPU(stages[Out.as_tensor_ref()], output_shapes.back(), target); + pe::ScheduleInjectiveCPU(stages[Out.as_tensor_ref()], output_shapes.front(), target); } *ret = arg_pack; }); @@ -89,6 +87,7 @@ std::vector InferShapeForBroadcast(const std::vector &inputs_s const Target &target) { CHECK_EQ(inputs_shape.size(), 2UL); std::vector out_shape; + int axis = -1; for (auto &iter : attrs.attr_store) { if (iter.first == "axis") { @@ -146,88 +145,105 @@ std::vector> InferLayoutForBroadcast(const std::vector< } } -std::shared_ptr StrategyForScale(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { - float scale = 1.f; - float bias = 0.f; - bool bias_after_scale = true; - for (auto &iter : attrs.attr_store) { - if (iter.first == "scale") { - scale = std::get(iter.second); - } else if (iter.first == "bias") { - bias = std::get(iter.second); - } else if (iter.first == "bias_after_scale") { - bias_after_scale = std::get(iter.second); - } +std::shared_ptr StrategyForBroadcastTo(const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + std::vector out_shape; + std::vector broadcast_axes; + if (attrs.attr_store.count("out_shape")) { + out_shape = std::get>(attrs.attr_store.at("out_shape")); + } + if (attrs.attr_store.count("broadcast_axes")) { + broadcast_axes = std::get>(attrs.attr_store.at("broadcast_axes")); } - framework::CINNCompute scale_compute([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input arguments of scale compute is empty! Please check."; + + framework::CINNCompute broadcast_to_compute([=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of broadcast_to compute is empty! Please check."; CINNValuePack a = args[0]; - CHECK(!a.empty()) << "The input tensors of scale compute is empty! Please check."; + CHECK(!a.empty()) << "The input tensors of broadcast_to compute is empty! Please check."; Expr A_expr = a[0]; CHECK(A_expr.as_tensor()); ir::Tensor A = A_expr.as_tensor_ref(); - ir::Tensor out; - if (bias_after_scale) { - out = Compute( - A->shape, [=](const std::vector &indice) { return scale * A(indice) + bias; }, UniqName("Scale_out")); - } else { - out = Compute( - A->shape, [=](const std::vector &indice) { return scale * (A(indice) + bias); }, UniqName("Scale_out")); - } - auto stages = CreateStages({out}); - *ret = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}}; + auto out = BroadcastTo(A, out_shape, broadcast_axes, UniqName("broadcast_to_Out")); + auto stages = CreateStages({A, out}); + *ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}}; }); - framework::CINNSchedule scale_schedule([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input arguments of scale schedule is empty! Please check."; + framework::CINNSchedule broadcast_to_schedule([=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of broadcast_to schedule is empty! Please check."; CINNValuePack arg_pack = args[0]; - CHECK_EQ(arg_pack.size(), 2UL) << "The input tensor's size of scale schedule is " << arg_pack.size() - << "and it should be equal to 2! Please check."; + CHECK_EQ(arg_pack.size(), 2UL); + Expr Out = arg_pack[0]; + poly::StageMap stages = arg_pack.back(); + CHECK(Out.as_tensor()); if (target.arch == Target::Arch::NVGPU) { - Expr Out = arg_pack[0]; - poly::StageMap stages = arg_pack[1]; - CHECK(Out.as_tensor()); - pe::CudaScheduleInjective(stages[Out.as_tensor_ref()], output_shapes.back(), target); + pe::CudaScheduleInjective(stages[Out.as_tensor_ref()], out_shape, target); + } else if (target.arch == Target::Arch::X86) { + pe::ScheduleInjectiveCPU(stages[Out.as_tensor_ref()], out_shape, target); } *ret = arg_pack; }); auto strategy = std::make_shared(); - strategy->AddImpl(scale_compute, scale_schedule, "strategy.scale.x86", 1); + strategy->AddImpl(broadcast_to_compute, broadcast_to_schedule, "strategy.broadcast_to.x86", 1); return strategy; } -std::vector InferShapeForScale(const std::vector &inputs_shape, - framework::NodeAttr &attrs, - const Target &target) { - CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) << "The input's shape size is 0! Please check again."; - return {{inputs_shape[0]}}; -} +std::vector InferShapeForBroadcastTo(const std::vector &inputs_shape, + framework::NodeAttr &attrs, + const Target &target) { + CHECK_EQ(inputs_shape.size(), 1UL) << "input_shape size should be one. Please Check."; + std::vector broadcast_axes; + std::vector out_shape; + CHECK(attrs.attr_store.count("broadcast_axes")); + CHECK(attrs.attr_store.count("out_shape")); + out_shape = std::get>(attrs.attr_store.at("out_shape")); + broadcast_axes = std::get>(attrs.attr_store.at("broadcast_axes")); -std::vector InferDtypeForScale(const std::vector &inputs_type, - framework::NodeAttr &attrs, - const Target &target) { - CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; - std::vector res{inputs_type[0]}; - return res; + CHECK_EQ(inputs_shape[0].size(), broadcast_axes.size()) + << "broadcast_axes's size should be same with the input shape's size"; + CHECK_GE(out_shape.size(), broadcast_axes.size()) << "broadcast_axes's size should be no more than out_shape's size"; + + VLOG(3) << "broadcast out shape: " << utils::Join(out_shape, ", "); + return {out_shape}; } -std::vector> InferLayoutForScale(const std::vector &input_shapes, - const std::vector &input_layouts, - const framework::NodeAttr &attrs, - const Target &target) { - CHECK_EQ(input_layouts.size(), 1U) << "The input's layouts size is not 1! Please check again."; - return {input_layouts, input_layouts}; +std::vector> InferLayoutForBroadcastTo(const std::vector> &input_shapes, + const std::vector &input_layouts, + const framework::NodeAttr &attrs, + const Target &target) { + CHECK(input_layouts.size() == 1U) << "The input's layouts size is not 1! Please check again."; + std::vector out_layouts = {""}; + if (attrs.attr_store.count("out_layouts")) { + out_layouts = std::get>(attrs.attr_store.at("out_layouts")); + } + return {out_layouts, input_layouts}; } StrategyForBinary(elementwise_add, Add); StrategyForBinary(elementwise_mul, Multiply); +StrategyForBinary(substract, Substract); +StrategyForBinary(divide, Divide); +StrategyForBinary(floor_divide, FloorDivide); +StrategyForBinary(mod, Mod); +StrategyForBinary(floor_mod, FloorMod); +StrategyForBinary(max, Maximum); +StrategyForBinary(min, Minimum); +StrategyForBinary(power, Power); +StrategyForBinary(logical_and, LogicaAnd); +StrategyForBinary(logical_or, LogicalOr); +StrategyForBinary(logical_xor, LogicalXOr); +StrategyForBinary(greater, Greater); +StrategyForBinary(less, Less); +StrategyForBinary(equal, Equal); +StrategyForBinary(not_equal, NotEqual); +StrategyForBinary(greater_equal, GreaterEqual); +StrategyForBinary(less_equal, LessEqual); + StrategyForBinary(bitwise_or, BitwiseOr); StrategyForBinary(bitwise_xor, BitwiseXor); StrategyForBinary(bitwise_and, BitwiseAnd); @@ -256,6 +272,24 @@ CINN_REGISTER_HELPER(broadcast_ops) { CINN_REGISTER_BINARY(elementwise_add, Add); CINN_REGISTER_BINARY(elementwise_mul, Multiply); + CINN_REGISTER_BINARY(substract, Substract); + CINN_REGISTER_BINARY(divide, Divide); + CINN_REGISTER_BINARY(floor_divide, FloorDivide); + CINN_REGISTER_BINARY(mod, Mod); + CINN_REGISTER_BINARY(floor_mod, FloorMod); + CINN_REGISTER_BINARY(max, Maximum); + CINN_REGISTER_BINARY(min, Minimum); + CINN_REGISTER_BINARY(power, Power); + CINN_REGISTER_BINARY(logical_and, LogicaAnd); + CINN_REGISTER_BINARY(logical_or, LogicalOr); + CINN_REGISTER_BINARY(logical_not, LogicalXOr); + CINN_REGISTER_BINARY(greater, Greater); + CINN_REGISTER_BINARY(less, Less); + CINN_REGISTER_BINARY(equal, Equal); + CINN_REGISTER_BINARY(not_equal, NotEqual); + CINN_REGISTER_BINARY(greater_equal, GreaterEqual); + CINN_REGISTER_BINARY(less_equal, LessEqual); + CINN_REGISTER_BINARY(bitwise_or, BitwiseOr); CINN_REGISTER_BINARY(bitwise_xor, BitwiseXor); CINN_REGISTER_BINARY(bitwise_and, BitwiseAnd); @@ -263,15 +297,15 @@ CINN_REGISTER_HELPER(broadcast_ops) { CINN_REGISTER_BINARY(right_shift, RightShift); #undef CINN_REGISTER_BINARY - CINN_REGISTER_OP(scale) - .describe("Putting scale and bias to the input Tensor") + CINN_REGISTER_OP(broadcast_to) + .describe("broadcast one tensor to the target shape") .set_num_inputs(1) .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForScale) - .set_attr("infershape", std::function(cinn::hlir::op::InferShapeForScale)) + .set_attr("CINNStrategy", cinn::hlir::op::StrategyForBroadcastTo) + .set_attr("infershape", std::function(cinn::hlir::op::InferShapeForBroadcastTo)) .set_attr("inferdtype", std::function(cinn::hlir::op::InferDtypeForBroadcast)) #ifndef CINN_WITH_CUDA - .set_attr("inferlayout", std::function(cinn::hlir::op::InferLayoutForScale)) + .set_attr("inferlayout", std::function(cinn::hlir::op::InferLayoutForBroadcastTo)) #endif .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kElemWise) .set_support_level(4); diff --git a/cinn/hlir/op/elementwise.cc b/cinn/hlir/op/elementwise.cc index 9ef3808ee9..289771c078 100644 --- a/cinn/hlir/op/elementwise.cc +++ b/cinn/hlir/op/elementwise.cc @@ -58,21 +58,14 @@ std::shared_ptr StrategyForElementwise(const framework::NodeAttr &at framework::CINNSchedule unary_schedule([=](lang::Args args, lang::RetValue *ret) { CHECK(!args.empty()) << "The input argument of " << op_name << " schedule is empty! Please check."; CINNValuePack arg_pack = args[0]; - CHECK(arg_pack.size() == 2UL || arg_pack.size() == 3UL); + CHECK_EQ(arg_pack.size(), 2UL); + Expr Out = arg_pack[0]; + poly::StageMap stages = arg_pack[1]; + CHECK(Out.as_tensor()); if (target.arch == Target::Arch::NVGPU) { - Expr Out = arg_pack[0]; - poly::StageMap stages = arg_pack.back(); - CHECK(Out.as_tensor()); - pe::CudaSplitSchedule(stages[Out.as_tensor_ref()], output_shapes.back()); - if (Out.as_tensor()->shape.size() > 1) { - stages[Out.as_tensor_ref()]->Bind(0, "blockIdx.x"); - stages[Out.as_tensor_ref()]->Bind(1, "threadIdx.x"); - } + pe::CudaScheduleInjective(stages[Out.as_tensor_ref()], output_shapes.front(), target); } else if (target.arch == Target::Arch::X86) { - Expr Out = arg_pack[0]; - poly::StageMap stages = arg_pack[1]; - CHECK(Out.as_tensor()); - pe::ScheduleInjectiveCPU(stages[Out.as_tensor_ref()], output_shapes.back(), target); + pe::ScheduleInjectiveCPU(stages[Out.as_tensor_ref()], output_shapes.front(), target); } *ret = arg_pack; }); @@ -107,6 +100,141 @@ std::vector> InferLayoutForElementwise(const std::vecto return {input_layouts, input_layouts}; } +std::shared_ptr StrategyForScale(const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + float scale = 1.f; + float bias = 0.f; + bool bias_after_scale = true; + for (auto &iter : attrs.attr_store) { + if (iter.first == "scale") { + scale = std::get(iter.second); + } else if (iter.first == "bias") { + bias = std::get(iter.second); + } else if (iter.first == "bias_after_scale") { + bias_after_scale = std::get(iter.second); + } + } + framework::CINNCompute scale_compute([=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input arguments of scale compute is empty! Please check."; + CINNValuePack a = args[0]; + CHECK(!a.empty()) << "The input tensors of scale compute is empty! Please check."; + Expr A_expr = a[0]; + CHECK(A_expr.as_tensor()); + ir::Tensor A = A_expr.as_tensor_ref(); + ir::Tensor out; + if (bias_after_scale) { + out = Compute( + A->shape, [=](const std::vector &indice) { return scale * A(indice) + bias; }, UniqName("Scale_out")); + } else { + out = Compute( + A->shape, [=](const std::vector &indice) { return scale * (A(indice) + bias); }, UniqName("Scale_out")); + } + auto stages = CreateStages({out}); + *ret = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}}; + }); + + framework::CINNSchedule scale_schedule([=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input arguments of scale schedule is empty! Please check."; + CINNValuePack arg_pack = args[0]; + CHECK_EQ(arg_pack.size(), 2UL) << "The input tensor's size of scale schedule is " << arg_pack.size() + << "and it should be equal to 2! Please check."; + Expr Out = arg_pack[0]; + poly::StageMap stages = arg_pack[1]; + CHECK(Out.as_tensor()); + if (target.arch == Target::Arch::NVGPU) { + pe::CudaScheduleInjective(stages[Out.as_tensor_ref()], output_shapes.front(), target); + } else if (target.arch == Target::Arch::X86) { + pe::ScheduleInjectiveCPU(stages[Out.as_tensor_ref()], output_shapes.front(), target); + } + *ret = arg_pack; + }); + + auto strategy = std::make_shared(); + strategy->AddImpl(scale_compute, scale_schedule, "strategy.scale.x86", 1); + + return strategy; +} + +Expr GetScalarExpr(const framework::NodeAttr::attr_t &attr) { + Expr scalar; + struct Visitor { + Expr &scalar_; + explicit Visitor(Expr &scalar) : scalar_(scalar) {} + void operator()(float v) { scalar_ = Expr(v); } + void operator()(int v) { scalar_ = Expr(v); } + void operator()(bool v) { scalar_ = Expr(v); } + void operator()(const std::string &v) { scalar_ = Expr(v); } + void operator()(const std::vector &) { LOG(FATAL) << "wrong type std::vector"; } + void operator()(const std::vector &) { LOG(FATAL) << "wrong type std::vector"; } + void operator()(const std::vector &) { LOG(FATAL) << "wrong type std::vector"; } + void operator()(const std::vector &) { LOG(FATAL) << "wrong type std::vector"; } + }; + std::visit(Visitor{scalar}, attr); + return scalar; +} + +std::shared_ptr StrategyForConstScalar(const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + framework::CINNCompute const_scalar_compute([=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of const_float compute is empty! Please check."; + auto scalar = GetScalarExpr(attrs.attr_store.at("value")); + auto out = lang::Compute( + {Expr(1)}, [=](const std::vector &indice) { return scalar; }, UniqName("const_scalar_Out")); + CHECK(out.defined()) << "can't create const scalar with the given type " << out_type[0]; + auto stages = CreateStages({out}); + *ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}}; + }); + + framework::CINNSchedule const_scalar_schedule([=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) << "The input argument of create_const_float schedule is empty! Please check."; + CINNValuePack arg_pack = args[0]; + CHECK_EQ(arg_pack.size(), 2UL); + Expr Out = arg_pack[0]; + poly::StageMap stages = arg_pack.back(); + CHECK(Out.as_tensor()); + if (target.arch == Target::Arch::NVGPU) { + pe::CudaScheduleInjective(stages[Out.as_tensor_ref()], output_shapes.front(), target); + } else if (target.arch == Target::Arch::X86) { + pe::ScheduleInjectiveCPU(stages[Out.as_tensor_ref()], output_shapes.front(), target); + } + *ret = arg_pack; + }); + + auto strategy = std::make_shared(); + strategy->AddImpl(const_scalar_compute, const_scalar_schedule, "strategy.const_scalar.x86", 1); + + return strategy; +} + +std::vector InferShapeForConstScalar(const std::vector &inputs_shape, + framework::NodeAttr &attrs, + const Target &target) { + return {{1}}; +} + +std::vector InferDtypeForConstScalar(const std::vector &inputs_type, + const framework::NodeAttr &attrs, + const Target &target) { + CHECK(attrs.attr_store.count("value")); + auto scalar = GetScalarExpr(attrs.attr_store.at("value")); + auto out_type = scalar->type(); + VLOG(3) << "scalar type: " << out_type; + return {out_type}; +} + +std::vector> InferLayoutForConstScalar(const std::vector &input_shapes, + const std::vector &input_layouts, + const framework::NodeAttr &attrs, + const Target &target) { + return {{"C"}, input_layouts}; +} + StrategyForUnary(exp, Exp); StrategyForUnary(erf, Erf); StrategyForUnary(sqrt, Sqrt); @@ -135,6 +263,14 @@ StrategyForUnary(isfinite, IsFinite); StrategyForUnary(isinf, IsInf); StrategyForUnary(bitwise_not, BitwiseNot); +StrategyForUnary(negative, Negative); +StrategyForUnary(identity, Identity); +StrategyForUnary(logica_not, LogicalNot); +StrategyForUnary(sign, Sign); +StrategyForUnary(abs, Abs); +StrategyForUnary(rsqrt, Rsqrt); +StrategyForUnary(sigmoid, Sigmoid); + #undef StrategyForUnary } // namespace op @@ -181,7 +317,42 @@ CINN_REGISTER_HELPER(elementwise_ops) { CINN_REGISTER_UNARY(isfinite, IsFinite) CINN_REGISTER_UNARY(isinf, IsInf) CINN_REGISTER_UNARY(bitwise_not, BitwiseNot) + + CINN_REGISTER_UNARY(negative, Negative) + CINN_REGISTER_UNARY(identity, Identity) + CINN_REGISTER_UNARY(logica_not, LogicalNot) + CINN_REGISTER_UNARY(sign, Sign) + CINN_REGISTER_UNARY(abs, Abs) + CINN_REGISTER_UNARY(rsqrt, Rsqrt) + CINN_REGISTER_UNARY(sigmoid, Sigmoid) + #undef CINN_REGISTER_UNARY + CINN_REGISTER_OP(scale) + .describe("Putting scale and bias to the input Tensor") + .set_num_inputs(1) + .set_num_outputs(1) + .set_attr("CINNStrategy", cinn::hlir::op::StrategyForScale) + .set_attr("infershape", std::function(cinn::hlir::op::InferShapeForElementwise)) + .set_attr("inferdtype", std::function(cinn::hlir::op::InferDtypeForElementwise)) +#ifndef CINN_WITH_CUDA + .set_attr("inferlayout", std::function(cinn::hlir::op::InferLayoutForElementwise)) +#endif + .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kElemWise) + .set_support_level(4); + + CINN_REGISTER_OP(const_scalar) + .describe("create const scalar with the given value") + .set_num_inputs(0) + .set_num_outputs(1) + .set_attr("CINNStrategy", cinn::hlir::op::StrategyForConstScalar) + .set_attr("infershape", std::function(cinn::hlir::op::InferShapeForConstScalar)) + .set_attr("inferdtype", std::function(cinn::hlir::op::InferDtypeForConstScalar)) +#ifndef CINN_WITH_CUDA + .set_attr("inferlayout", std::function(cinn::hlir::op::InferLayoutForConstScalar)) +#endif + .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kElemWise) + .set_support_level(4); + return true; } diff --git a/cinn/hlir/op/nn.cc b/cinn/hlir/op/nn.cc index 1de7f447c4..06580e1034 100644 --- a/cinn/hlir/op/nn.cc +++ b/cinn/hlir/op/nn.cc @@ -44,12 +44,12 @@ std::shared_ptr StrategyForRelu(const framework::NodeAttr &attrs, Expr out = arg_pack[0]; poly::StageMap stages = arg_pack[1]; CHECK(out.as_tensor()); - pe::CudaScheduleInjective(stages[out.as_tensor_ref()], output_shapes.back(), target); + pe::CudaScheduleInjective(stages[out.as_tensor_ref()], output_shapes.front(), target); } else if (target.arch == Target::Arch::X86) { Expr out = arg_pack[0]; poly::StageMap stages = arg_pack[1]; CHECK(out.as_tensor()); - pe::ScheduleInjectiveCPU(stages[out.as_tensor_ref()], output_shapes.back(), target); + pe::ScheduleInjectiveCPU(stages[out.as_tensor_ref()], output_shapes.front(), target); } *ret = arg_pack; }); @@ -104,12 +104,12 @@ std::shared_ptr StrategyForRelu6(const framework::NodeAttr &attrs, Expr out = arg_pack[0]; poly::StageMap stages = arg_pack[1]; CHECK(out.as_tensor()); - pe::CudaScheduleInjective(stages[out.as_tensor_ref()], output_shapes.back(), target); + pe::CudaScheduleInjective(stages[out.as_tensor_ref()], output_shapes.front(), target); } else if (target.arch == Target::Arch::X86) { Expr out = arg_pack[0]; poly::StageMap stages = arg_pack[1]; CHECK(out.as_tensor()); - pe::ScheduleInjectiveCPU(stages[out.as_tensor_ref()], output_shapes.back(), target); + pe::ScheduleInjectiveCPU(stages[out.as_tensor_ref()], output_shapes.front(), target); } *ret = arg_pack; }); @@ -868,18 +868,20 @@ std::shared_ptr StrategyForBatchNorm(const framework::NodeAttr &attr UniqName("BatchNorm_output")); } auto stages = CreateStages({out}); - *ret = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}}; + *ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}}; }); framework::CINNSchedule batchnorm_schedule([=](lang::Args args, lang::RetValue *ret) { CHECK(!args.empty()) << "The input argument of batchnorm schedule is empty! Please check.\n"; CINNValuePack arg_pack = args[0]; CHECK_EQ(arg_pack.size(), 2UL); + Expr Out = arg_pack[0]; + poly::StageMap stages = arg_pack[1]; + CHECK(Out.as_tensor()); if (target.arch == Target::Arch::NVGPU) { - Expr Out = arg_pack[0]; - poly::StageMap stages = arg_pack[1]; - CHECK(Out.as_tensor()); - pe::CudaScheduleInjective(stages[Out.as_tensor_ref()], output_shapes.back(), target); + pe::CudaScheduleInjective(stages[Out.as_tensor_ref()], output_shapes.front(), target); + } else if (target.arch == Target::Arch::X86) { + pe::ScheduleInjectiveCPU(stages[Out.as_tensor_ref()], output_shapes.front(), target); } *ret = arg_pack; }); @@ -1438,62 +1440,6 @@ std::vector> InferLayoutForPool(const std::vector StrategyForSigmoid(const framework::NodeAttr &attrs, - const std::vector &inputs, - const std::vector &out_type, - const std::vector> &output_shapes, - const Target &target) { - framework::CINNCompute sigmoid_compute([](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of sigmoid compute is empty! Please check.\n"; - CINNValuePack a = args[0]; - CHECK(!a.empty()) << "at least one input tensor for sigmoid compute\n"; - Expr A = a[0]; - CHECK(A.as_tensor()); - auto out = pe::Sigmoid(A.as_tensor_ref(), UniqName("Sigmoid_output")); - CHECK(!out.empty()); - auto stages = CreateStages({out}); - *ret = CINNValuePack{{CINNValue(Expr(out.front())), CINNValue(stages)}}; - }); - - framework::CINNSchedule sigmoid_schedule([=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) << "The input argument of sigmoid schedule is empty! Please check.\n"; - CINNValuePack arg_pack = args[0]; - CHECK_EQ(arg_pack.size(), 2UL); - if (target.arch == Target::Arch::NVGPU) { - Expr Out = arg_pack[0]; - poly::StageMap stages = arg_pack[1]; - CHECK(Out.as_tensor()); - pe::CudaScheduleInjective(stages[Out.as_tensor_ref()], output_shapes.back(), target); - } - *ret = arg_pack; - }); - - auto strategy = std::make_shared(); - CHECK(out_type.size()) << "Out_type of sigmoid op is empty! Please check."; - if (out_type[0] == Float(32)) { - strategy->AddImpl(sigmoid_compute, sigmoid_schedule, "strategy.sigmoid.x86", 1); - } else { - LOG(FATAL) << "Sigmoid op with dtype != float32 is not implemented yet!"; - } - return strategy; -} - -std::vector InferShapeForSigmoid(const std::vector &inputs_shape, - framework::NodeAttr &attrs, - const Target &target) { - CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) << "The input's shape size is 0! Please check again."; - std::vector res{inputs_shape[0]}; - return res; -} - -std::vector InferDtypeForSigmoid(const std::vector &inputs_type, - const framework::NodeAttr &attrs, - const Target &target) { - CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again."; - std::vector res{inputs_type[0]}; - return res; -} - std::shared_ptr StrategyForSoftmax(const framework::NodeAttr &attrs, const std::vector &inputs, const std::vector &out_type, @@ -1645,11 +1591,13 @@ std::shared_ptr StrategyForSlice(const framework::NodeAttr &attrs, CINNValuePack arg_pack = args[0]; CHECK_EQ(arg_pack.size(), 2UL) << "The input tensor's size of slice schedule is " << arg_pack.size() << "and it should be equal to 2! Please check."; + Expr Out = arg_pack[0]; + poly::StageMap stages = arg_pack[1]; + CHECK(Out.as_tensor()); if (target.arch == Target::Arch::NVGPU) { - Expr Out = arg_pack[0]; - poly::StageMap stages = arg_pack[1]; - CHECK(Out.as_tensor()); - pe::CudaScheduleInjective(stages[Out.as_tensor_ref()], output_shapes.back(), target); + pe::CudaScheduleInjective(stages[Out.as_tensor_ref()], output_shapes.front(), target); + } else { + pe::ScheduleInjectiveCPU(stages[Out.as_tensor_ref()], output_shapes.front(), target); } *ret = arg_pack; }); @@ -1782,13 +1730,13 @@ std::shared_ptr StrategyForDropoutInfer(const framework::NodeAttr &a CINNValuePack arg_pack = args[0]; CHECK_EQ(arg_pack.size(), 2UL) << "The input tensor's size of dropout_infer schedule is " << arg_pack.size() << "and it should be equal to 2! Please check."; + Expr Out = arg_pack[0]; + poly::StageMap stages = arg_pack[1]; + CHECK(Out.as_tensor()); if (target.arch == Target::Arch::NVGPU) { - Expr Out = arg_pack[0]; - poly::StageMap stages = arg_pack[1]; - CHECK(Out.as_tensor()); - pe::CudaSplitSchedule(stages[Out.as_tensor_ref()], output_shapes.back()); - stages[Out.as_tensor_ref()]->Bind(0, "blockIdx.x"); - stages[Out.as_tensor_ref()]->Bind(1, "threadIdx.x"); + pe::CudaScheduleInjective(stages[Out.as_tensor_ref()], output_shapes.front(), target); + } else { + pe::ScheduleInjectiveCPU(stages[Out.as_tensor_ref()], output_shapes.front(), target); } *ret = arg_pack; }); @@ -1976,19 +1924,6 @@ CINN_REGISTER_HELPER(nn_ops) { .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kOpaque) .set_support_level(4); - CINN_REGISTER_OP(sigmoid) - .describe("Apply sigmoid activation on input tensor. Y = 1 / (1 + Exp(-X))") - .set_num_inputs(1) - .set_num_outputs(1) - .set_attr("CINNStrategy", cinn::hlir::op::StrategyForSigmoid) - .set_attr("infershape", std::function(cinn::hlir::op::InferShapeForSigmoid)) - .set_attr("inferdtype", std::function(cinn::hlir::op::InferDtypeForSigmoid)) -#ifndef CINN_WITH_CUDA - .set_attr("inferlayout", std::function(cinn::hlir::op::InferLayoutForUnary)) -#endif - .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kElemWise) - .set_support_level(4); - CINN_REGISTER_OP(softmax) .describe("This operator implements the softmax layer") .set_num_inputs(1) diff --git a/cinn/hlir/op/op_util.cc b/cinn/hlir/op/op_util.cc new file mode 100644 index 0000000000..85b1ffebaf --- /dev/null +++ b/cinn/hlir/op/op_util.cc @@ -0,0 +1 @@ +#include "cinn/hlir/op/op_util.h" diff --git a/cinn/hlir/op/op_util.h b/cinn/hlir/op/op_util.h new file mode 100644 index 0000000000..3128bf6c77 --- /dev/null +++ b/cinn/hlir/op/op_util.h @@ -0,0 +1,18 @@ +#pragma once +#include +#include + +#include "cinn/ir/ir.h" + +namespace cinn { +namespace hlir { + +template +std::vector ToCinnExprs(const std::vector& args) { + std::vector exprs; + std::transform(args.begin(), args.end(), std::back_inserter(exprs), [](const T& arg) { return Expr(arg); }); + return exprs; +} + +} // namespace hlir +} // namespace cinn diff --git a/cinn/hlir/pass/CMakeLists.txt b/cinn/hlir/pass/CMakeLists.txt index bfd430ee1b..a20a06b940 100755 --- a/cinn/hlir/pass/CMakeLists.txt +++ b/cinn/hlir/pass/CMakeLists.txt @@ -8,6 +8,7 @@ core_gather_srcs(SRCS cc_test(test_opfusion SRCS opfusion_test.cc DEPS cinncore) +cc_test(test_primitive_ops SRCS test_primitive_ops.cc DEPS cinncore) if (NOT WITH_CUDA) cc_test(test_alterlayout SRCS alterlayout_test.cc DEPS cinncore) endif() diff --git a/cinn/hlir/pass/opfusion.cc b/cinn/hlir/pass/opfusion.cc index 2f135b6e21..5e71389daf 100644 --- a/cinn/hlir/pass/opfusion.cc +++ b/cinn/hlir/pass/opfusion.cc @@ -157,7 +157,7 @@ class DomTree { CHECK(graph_node); DomNode* dom_node = new DomNode(); dom_node->ref_node = graph_node; - if (graph_node->inlinks().empty()) { + if (graph_node->inlinks().empty() && graph_node->safe_as()) { CHECK(graph_node->safe_as()); // extern input vars dom_node->parent = nullptr; @@ -265,9 +265,10 @@ class GraphPartition { } return out_shapes; } - bool IsSameOutShape(GraphNode* node1, GraphNode* node2) { + bool VerifyOutShape(GraphNode* node1, GraphNode* node2) { auto out_shape1 = GetOutshape(node1); auto out_shape2 = GetOutshape(node2); + if (out_shape1.size() == 1 || out_shape2.size() == 1) return true; if (out_shape1.size() != out_shape2.size()) return false; VLOG(2) << node1->id() << ", out_shape1: " << utils::Join(out_shape1, ", "); VLOG(2) << node2->id() << ", out_shape2: " << utils::Join(out_shape2, ", "); @@ -313,7 +314,7 @@ class GraphPartition { auto op_node = source->safe_as(); visited_nodes_.clear(); CHECK(source != sink); - if (!IsSameOutShape(source, sink)) return false; + if (!VerifyOutShape(source, sink)) return false; if (op_node) { auto& outlinks = op_node->outlinks_in_order(true); for (int i = 0; i < outlinks.size(); i++) { diff --git a/cinn/hlir/pass/test_primitive_ops.cc b/cinn/hlir/pass/test_primitive_ops.cc new file mode 100644 index 0000000000..04f9190345 --- /dev/null +++ b/cinn/hlir/pass/test_primitive_ops.cc @@ -0,0 +1,92 @@ + +#include + +#include + +#include "cinn/cinn.h" +#include "cinn/frontend/syntax.h" +#include "cinn/hlir/framework/graph.h" +#include "cinn/hlir/framework/graph_compiler.h" +#include "cinn/hlir/framework/pass.h" +#include "cinn/hlir/op/use_ops.h" +#include "cinn/hlir/pass/use_pass.h" + +DEFINE_string(model_dir, "", ""); + +namespace cinn { +namespace frontend { + +using hlir::framework::Scope; +using utils::Join; + +Target GetTarget() { +#ifdef CINN_WITH_CUDA + return common::DefaultNVGPUTarget(); +#else + return common::DefaultHostTarget(); +#endif +} + +void SetRandData(const hlir::framework::Tensor& tensor, Target target) { +#ifdef CINN_WITH_CUDA + auto* data = tensor->mutable_data(target); + std::vector host_memory(tensor->shape().numel(), 0); + for (float& v : host_memory) { + v = (rand() * 1.f) / RAND_MAX; // All random data + } + CUDA_CALL(cudaMemcpy(reinterpret_cast(data), + host_memory.data(), + tensor->shape().numel() * sizeof(float), + cudaMemcpyHostToDevice)); +#else + auto* data = tensor->mutable_data(target); + for (size_t j = 0; j < tensor->shape().numel(); j++) { + data[j] = (rand() * 1.f) / RAND_MAX; // All random data + } +#endif +} + +// batch_norm primitives +TEST(batch_norm_meta, batch_norm_meta) { + Placeholder A(Float(32), {1, 64, 112, 112}, "A"); + + Placeholder Scale(Float(32), {64}, "Scale"); + Placeholder Bias(Float(32), {64}, "Bias"); + Placeholder Mean(Float(32), {64}, "Mean"); + Placeholder Variance(Float(32), {64}, "Variance"); + + Program program; + std::unordered_map attrs; + attrs["epsilon"] = static_cast(0.001); + + auto a = program.batchnorm(A, Scale, Bias, Mean, Variance, attrs); + + auto b = program.fused_batchnorm_inference(A, Scale, Bias, Mean, Variance, attrs); + + Target target = GetTarget(); + program.SetInputs({A}); + program.Validate(); + LOG(INFO) << "Program:\n" << program; + auto graph = std::make_shared(program, target); + + hlir::framework::ApplyPass(graph.get(), "InferShape"); +#ifndef CINN_WITH_CUDA + hlir::framework::ApplyPass(graph.get(), "AlterLayout"); +#endif + hlir::framework::ApplyPass(graph.get(), "OpFusion"); + auto scope = BuildScope(target, graph); + LOG(INFO) << "graph:\n" << graph->Visualize(); + + hlir::framework::GraphCompiler gc(target, scope, graph); + auto runtime_program = gc.Build(); + + scope->Var("A"); + + auto A1 = scope->GetTensor("A"); + SetRandData(A1, target); + + runtime_program->Execute(); +} + +} // namespace frontend +} // namespace cinn diff --git a/cinn/hlir/pe/broadcast.cc b/cinn/hlir/pe/broadcast.cc index cab83191ff..f94b6ef602 100644 --- a/cinn/hlir/pe/broadcast.cc +++ b/cinn/hlir/pe/broadcast.cc @@ -3,6 +3,7 @@ #include #include "cinn/common/ir_util.h" +#include "cinn/hlir/op/op_util.h" #include "cinn/ir/ir_base.h" #include "cinn/ir/ir_operators.h" #include "cinn/lang/builtin.h" @@ -216,6 +217,33 @@ HLIR_IMP_BC_PE(NotEqual, return ir::NE::Make(a, b);); HLIR_IMP_BC_PE(GreaterEqual, return a >= b;); HLIR_IMP_BC_PE(LessEqual, return a <= b;); +Tensor BroadcastTo(const Tensor& A, + const std::vector& out_shape, + const std::vector& broadcast_axes, + const std::string& out_name) { + auto A_shape = A->shape; + CHECK_EQ(A_shape.size(), broadcast_axes.size()) << "broadcast_axes's size should be same with the input shape's size"; + CHECK_GE(out_shape.size(), broadcast_axes.size()) << "broadcast_axes's size should be no more than out_shape's size"; + + return Compute( + ToCinnExprs(out_shape), + [=](const std::vector& indice) { + std::vector broadcast_indice; + for (int i = 0; i < broadcast_axes.size(); i++) { + int a_shape_i = A_shape[i].as_int32(); + CHECK(broadcast_axes[i] >= 0 && broadcast_axes[i] < out_shape.size()) + << "broadcast_axis should be no less than 0 and no more than out_shape's dim. Current broadcast axis is " + << broadcast_axes[i]; + CHECK(a_shape_i == 1 || a_shape_i == out_shape[broadcast_axes[i]]) + << "broadcast_shape should be 1 or same with the target mapping dim, but get " << A_shape[i] << " and " + << out_shape[broadcast_axes[i]]; + broadcast_indice.push_back(indice[broadcast_axes[i]]); + } + return A(broadcast_indice); + }, + out_name); +} + } // namespace pe } // namespace hlir } // namespace cinn diff --git a/cinn/hlir/pe/broadcast.h b/cinn/hlir/pe/broadcast.h index 414ddb2390..c2d505e718 100644 --- a/cinn/hlir/pe/broadcast.h +++ b/cinn/hlir/pe/broadcast.h @@ -28,10 +28,10 @@ void GetBroadcastOutShape(const std::vector& input_shape1, * shape(A) = (2, 3, 4, 5), shape(B) = (2), with axis=0 * shape(A) = (2, 3, 4, 5), shape(B) = (2, 1), with axis=0 */ -#define HLIR_DCL_BC_PE(name__) \ - ir::Tensor name__(const ir::Tensor& A, \ - const ir::Tensor& B, \ - const std::string& out_name = "T_" #name__ "_out", \ +#define HLIR_DCL_BC_PE(name__) \ + ir::Tensor name__(const ir::Tensor& A, \ + const ir::Tensor& B, \ + const std::string& out_name = common::UniqName("T_" #name__ "_out"), \ const Expr& axis = Expr()); //! Compute A + B with auto-broadcasting. @@ -83,6 +83,11 @@ HLIR_DCL_BC_PE(GreaterEqual); //! Compute A <= B with auto-broadcasting. HLIR_DCL_BC_PE(LessEqual); +ir::Tensor BroadcastTo(const ir::Tensor& A, + const std::vector& out_shape, + const std::vector& broadcast_axes, + const std::string& out_name = common::UniqName("T_broadcast_to_out")); + } // namespace pe } // namespace hlir } // namespace cinn diff --git a/cinn/hlir/pe/elementwise.cc b/cinn/hlir/pe/elementwise.cc index 84113228ff..9095f9d0da 100644 --- a/cinn/hlir/pe/elementwise.cc +++ b/cinn/hlir/pe/elementwise.cc @@ -4,15 +4,14 @@ #include "cinn/ir/ir_operators.h" #include "cinn/lang/builtin.h" -#include "cinn/lang/compute.h" namespace cinn { namespace hlir { namespace pe { -using cinn::lang::Compute; using ir::Expr; using ir::Tensor; +using lang::Compute; #define HLIR_IMP_UNARY_PE(name__) \ std::vector name__(const Tensor& A, const std::string& output_name) { \ diff --git a/cinn/hlir/pe/elementwise.h b/cinn/hlir/pe/elementwise.h index 6977967f6d..99662d6381 100644 --- a/cinn/hlir/pe/elementwise.h +++ b/cinn/hlir/pe/elementwise.h @@ -4,6 +4,7 @@ #include #include "cinn/ir/ir.h" +#include "cinn/lang/compute.h" namespace cinn { namespace hlir { diff --git a/tests/benchmark/test_all_ops_default.cc b/tests/benchmark/test_all_ops_default.cc index ab274960cf..286f3af2bf 100644 --- a/tests/benchmark/test_all_ops_default.cc +++ b/tests/benchmark/test_all_ops_default.cc @@ -52,8 +52,8 @@ using AttrType = std::variant type = {Float(32)}; -std::vector type1{Float(32), Float(32)}; +std::vector type = {Float(32)}; +std::vector type1 = {Float(32), Float(32)}; std::vector type2 = {Int(32)}; std::vector type3 = {Bool()}; std::vector type4 = {Float(32), Float(32), Float(32), Float(32), Float(32)}; @@ -62,6 +62,14 @@ std::vector type6 = {Float(32), Void()}; std::vector type7 = {Float(32), Float(32), Float(32), Float(32)}; std::vector type8 = {Float(32), Float(32), Float(32)}; +// broadcast_to +std::vector> shapes_broadcast_to = {{32}}; +std::vector out_shape = {100, 32}; +std::vector broadcast_axes = {1}; +std::unordered_map attr_store_broadcast_to = {{"out_shape", out_shape}, + {"broadcast_axes", broadcast_axes}}; +TEST_DEFAULT1(broadcast_to, broadcast_to, type, type, attr_store_broadcast_to) + // add std::vector> shapes_add = {{1024, 1024, 1024}, {1024, 1024, 1024}}; TEST_DEFAULT(elementwise_add, add, type1, type)