This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

add several primitive ops (#438)
* meta op demo

* add const_scalar and broadcast_to primitive ops
wenming2014 authored Sep 27, 2021
1 parent 5120147 commit bf94356
Showing 18 changed files with 722 additions and 190 deletions.
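For orientation, a minimal sketch of how the two new frontend ops compose; the helper name MakeBroadcastEps and the literal 1e-5f are illustrative assumptions, not part of this commit:

#include "cinn/frontend/syntax.h"

using cinn::frontend::Program;
using cinn::frontend::Variable;

// Hypothetical helper: build an epsilon constant and broadcast it to scale's
// 1-D shape — the same pattern fused_batchnorm_inference uses below.
Variable MakeBroadcastEps(Program* prog, const Variable& scale) {
  // Scalar constant with value 1e-5f and a unique name.
  auto eps = prog->primitive_const_scalar<float>(1e-5f, cinn::common::UniqName("epsilon"));
  // Map eps's axis 0 onto axis 0 of the target shape.
  return prog->primitive_broadcast_to(eps, scale->shape, {0});
}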
134 changes: 133 additions & 1 deletion cinn/frontend/syntax.cc
100755 → 100644
@@ -107,6 +107,59 @@ Variable Program::batchnorm(const Variable& a,
return instr.GetOutput(0);
}

template <typename PrimType>
Variable Program::primitive_const_scalar(PrimType value, const std::string& name) {
Instruction instr("const_scalar");
instr.SetInputs({});
instr.SetAttr("value", value);
AppendInstruction(instr);
auto out = instr.GetOutput(0);
out.set_id(name);
auto out_type = type_of<PrimType>();
CHECK(out_type.is_float() || out_type.is_int()) << "unsupported type: " << out_type;
out->type = out_type;
return out;
}

Variable Program::primitive_broadcast_to(const Variable& a,
const std::vector<int>& out_shape,
const std::vector<int>& broadcast_axes) {
Instruction instr("broadcast_to");
instr.SetInputs({a});
instr.SetAttr("out_shape", out_shape);
instr.SetAttr("broadcast_axes", broadcast_axes);
AppendInstruction(instr);
return instr.GetOutput(0);
}

Variable Program::fused_batchnorm_inference(const Variable& a,
const Variable& scale,
const Variable& bias,
const Variable& mean,
const Variable& variance,
const std::unordered_map<std::string, attr_t>& attr_store) {
float epsilon = 0.00001f;
if (attr_store.find("epsilon") != attr_store.end()) {
epsilon = std::get<float>(attr_store.at("epsilon"));
}
auto eps_var = primitive_const_scalar<float>(epsilon, common::UniqName("epsilon"));
CHECK(!scale->shape.empty()) << "scale's shape is empty.";
auto broadcast_eps = primitive_broadcast_to(eps_var, scale->shape, {0});
auto var_add_eps = add(variance, broadcast_eps);
auto rsqrt_var = primitive_rsqrt(var_add_eps);
auto new_scale = multiply(rsqrt_var, scale);
auto neg_mean = primitive_negative(mean);
auto new_shift = multiply(new_scale, neg_mean);
auto shift_bias = add(new_shift, bias);
CHECK(!a->shape.empty()) << "variable a's shape is empty.";
auto broadcast_new_scale = primitive_broadcast_to(new_scale, a->shape, {1});
auto broadcast_shift_bias = primitive_broadcast_to(shift_bias, a->shape, {1});
auto temp_out = multiply(broadcast_new_scale, a);
auto bn_out = add(temp_out, broadcast_shift_bias);

return bn_out;
}
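For reference, the sequence above implements the standard inference-time batch-norm factorization; in LaTeX, with names matching the code:

y = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^{2} + \epsilon}} + \beta
  = \underbrace{\gamma \cdot \operatorname{rsqrt}(\sigma^{2} + \epsilon)}_{\texttt{new\_scale}} \cdot x
    + \underbrace{\beta - \texttt{new\_scale} \cdot \mu}_{\texttt{shift\_bias}}

Folding rsqrt(variance + epsilon) into the scale ahead of time reduces inference batch norm to one broadcast multiply and one broadcast add per element.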

Variable Program::scale(const Variable& a, const std::unordered_map<std::string, attr_t>& attr_store) {
Instruction instr("scale", {a});
for (auto& iter : attr_store) {
@@ -198,6 +251,85 @@ Variable Program::add(const Variable& a, const Variable& b) {
return instr.GetOutput(0);
}

Variable Program::multiply(const Variable& a, const Variable& b) {
Instruction instr("elementwise_mul", {a, b});
AppendInstruction(instr);
return instr.GetOutput(0);
}

#define SYNTAX_PRIM_UNARY_IMPL(name__) \
Variable Program::primitive_##name__(const Variable& a) { \
Instruction instr(#name__, {a}); \
AppendInstruction(instr); \
return instr.GetOutput(0); \
}
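The expansion is purely mechanical; for example, SYNTAX_PRIM_UNARY_IMPL(exp) generates:

// Generated by SYNTAX_PRIM_UNARY_IMPL(exp): an "exp" instruction with one input.
Variable Program::primitive_exp(const Variable& a) {
  Instruction instr("exp", {a});
  AppendInstruction(instr);
  return instr.GetOutput(0);
}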

SYNTAX_PRIM_UNARY_IMPL(exp);
SYNTAX_PRIM_UNARY_IMPL(erf);
SYNTAX_PRIM_UNARY_IMPL(sqrt);
SYNTAX_PRIM_UNARY_IMPL(log);
SYNTAX_PRIM_UNARY_IMPL(floor);
SYNTAX_PRIM_UNARY_IMPL(ceil);
SYNTAX_PRIM_UNARY_IMPL(round);
SYNTAX_PRIM_UNARY_IMPL(tanh);
SYNTAX_PRIM_UNARY_IMPL(log2);
SYNTAX_PRIM_UNARY_IMPL(log10);
SYNTAX_PRIM_UNARY_IMPL(trunc);
SYNTAX_PRIM_UNARY_IMPL(cos);
SYNTAX_PRIM_UNARY_IMPL(sin);
SYNTAX_PRIM_UNARY_IMPL(cosh);
SYNTAX_PRIM_UNARY_IMPL(tan);
SYNTAX_PRIM_UNARY_IMPL(sinh);
SYNTAX_PRIM_UNARY_IMPL(acos);
SYNTAX_PRIM_UNARY_IMPL(acosh);
SYNTAX_PRIM_UNARY_IMPL(asin);
SYNTAX_PRIM_UNARY_IMPL(asinh);
SYNTAX_PRIM_UNARY_IMPL(atan);
SYNTAX_PRIM_UNARY_IMPL(atanh);

SYNTAX_PRIM_UNARY_IMPL(isnan);
SYNTAX_PRIM_UNARY_IMPL(isfinite);
SYNTAX_PRIM_UNARY_IMPL(isinf);
SYNTAX_PRIM_UNARY_IMPL(bitwise_not);

SYNTAX_PRIM_UNARY_IMPL(negative);
SYNTAX_PRIM_UNARY_IMPL(identity);
SYNTAX_PRIM_UNARY_IMPL(logical_not);
SYNTAX_PRIM_UNARY_IMPL(sign);
SYNTAX_PRIM_UNARY_IMPL(abs);
SYNTAX_PRIM_UNARY_IMPL(rsqrt);

#define SYNTAX_PRIM_BINARY_IMPL(name__) \
Variable Program::primitive_##name__(const Variable& a, const Variable& b) { \
Instruction instr(#name__, {a, b}); \
AppendInstruction(instr); \
return instr.GetOutput(0); \
}

SYNTAX_PRIM_BINARY_IMPL(subtract)
SYNTAX_PRIM_BINARY_IMPL(divide)
SYNTAX_PRIM_BINARY_IMPL(floor_divide)
SYNTAX_PRIM_BINARY_IMPL(mod)
SYNTAX_PRIM_BINARY_IMPL(floor_mod)
SYNTAX_PRIM_BINARY_IMPL(max)
SYNTAX_PRIM_BINARY_IMPL(min)
SYNTAX_PRIM_BINARY_IMPL(power)
SYNTAX_PRIM_BINARY_IMPL(logical_and)
SYNTAX_PRIM_BINARY_IMPL(logical_or)
SYNTAX_PRIM_BINARY_IMPL(logical_xor)
SYNTAX_PRIM_BINARY_IMPL(greater)
SYNTAX_PRIM_BINARY_IMPL(less)
SYNTAX_PRIM_BINARY_IMPL(equal)
SYNTAX_PRIM_BINARY_IMPL(not_equal)
SYNTAX_PRIM_BINARY_IMPL(greater_equal)
SYNTAX_PRIM_BINARY_IMPL(less_equal)

SYNTAX_PRIM_BINARY_IMPL(bitwise_or)
SYNTAX_PRIM_BINARY_IMPL(bitwise_xor)
SYNTAX_PRIM_BINARY_IMPL(bitwise_and)
SYNTAX_PRIM_BINARY_IMPL(left_shift)
SYNTAX_PRIM_BINARY_IMPL(right_shift)

Variable Program::elementwise_add(const Variable& a, const Variable& b, int axis) {
Instruction instr("elementwise_add", {a, b});
instr.SetAttr("axis", axis);
@@ -267,7 +399,7 @@ std::string _Instruction_::debug_string() const {
ss << op_type;
ss << "(";
ss << utils::Join(input_names, ", ");
-  if (!attrs.empty()) ss << ", ";
+  if (!attrs.empty() && !input_names.empty()) ss << ", ";

std::vector<std::string> attr_strs;
for (auto& attr : attrs) {
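The !input_names.empty() guard added above keeps debug_string() from printing a stray leading comma for zero-input ops such as the new const_scalar; hypothetical output:

// before the guard: const_scalar(, value=1e-05f)
// with the guard:   const_scalar(value=1e-05f)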
87 changes: 87 additions & 0 deletions cinn/frontend/syntax.h
@@ -170,6 +170,12 @@ struct Program {
: instrs_(std::move(instrs)), inputs_(std::move(inputs)) {}

void SetInputs(const std::vector<Variable>& xs);

/**
 * Create a scalar constant with the given value and type.
 */
template <typename PrimType>
Variable primitive_const_scalar(PrimType value, const std::string& name);
/**
* Add two variables.
*
@@ -178,6 +184,7 @@
* @return The result.
*/
Variable add(const Variable& a, const Variable& b);
Variable multiply(const Variable& a, const Variable& b);

/**
* Multiply two matrices.
@@ -190,6 +197,76 @@
Variable mulbias(
const Variable& a, const Variable& b, const Variable& c, int x_num_col_dims = 1, int y_num_col_dims = 1);

#define SYNTAX_PRIM_UNARY_DECL(name__) Variable primitive_##name__(const Variable& a);

SYNTAX_PRIM_UNARY_DECL(exp);
SYNTAX_PRIM_UNARY_DECL(erf);
SYNTAX_PRIM_UNARY_DECL(sqrt);
SYNTAX_PRIM_UNARY_DECL(log);
SYNTAX_PRIM_UNARY_DECL(floor);
SYNTAX_PRIM_UNARY_DECL(ceil);
SYNTAX_PRIM_UNARY_DECL(round);
SYNTAX_PRIM_UNARY_DECL(tanh);
SYNTAX_PRIM_UNARY_DECL(log2);
SYNTAX_PRIM_UNARY_DECL(log10);
SYNTAX_PRIM_UNARY_DECL(trunc);
SYNTAX_PRIM_UNARY_DECL(cos);
SYNTAX_PRIM_UNARY_DECL(sin);
SYNTAX_PRIM_UNARY_DECL(cosh);
SYNTAX_PRIM_UNARY_DECL(tan);
SYNTAX_PRIM_UNARY_DECL(sinh);
SYNTAX_PRIM_UNARY_DECL(acos);
SYNTAX_PRIM_UNARY_DECL(acosh);
SYNTAX_PRIM_UNARY_DECL(asin);
SYNTAX_PRIM_UNARY_DECL(asinh);
SYNTAX_PRIM_UNARY_DECL(atan);
SYNTAX_PRIM_UNARY_DECL(atanh);

SYNTAX_PRIM_UNARY_DECL(isnan);
SYNTAX_PRIM_UNARY_DECL(isfinite);
SYNTAX_PRIM_UNARY_DECL(isinf);
SYNTAX_PRIM_UNARY_DECL(bitwise_not);

SYNTAX_PRIM_UNARY_DECL(negative);
SYNTAX_PRIM_UNARY_DECL(identity);
SYNTAX_PRIM_UNARY_DECL(logical_not);
SYNTAX_PRIM_UNARY_DECL(sign);
SYNTAX_PRIM_UNARY_DECL(abs);
SYNTAX_PRIM_UNARY_DECL(rsqrt);

#define SYNTAX_PRIM_BINARY_DECL(name__) Variable primitive_##name__(const Variable& a, const Variable& b);
SYNTAX_PRIM_BINARY_DECL(subtract)
SYNTAX_PRIM_BINARY_DECL(divide)
SYNTAX_PRIM_BINARY_DECL(floor_divide)
SYNTAX_PRIM_BINARY_DECL(mod)
SYNTAX_PRIM_BINARY_DECL(floor_mod)
SYNTAX_PRIM_BINARY_DECL(max)
SYNTAX_PRIM_BINARY_DECL(min)
SYNTAX_PRIM_BINARY_DECL(power)
SYNTAX_PRIM_BINARY_DECL(logical_and)
SYNTAX_PRIM_BINARY_DECL(logical_or)
SYNTAX_PRIM_BINARY_DECL(logical_xor)
SYNTAX_PRIM_BINARY_DECL(greater)
SYNTAX_PRIM_BINARY_DECL(less)
SYNTAX_PRIM_BINARY_DECL(equal)
SYNTAX_PRIM_BINARY_DECL(not_equal)
SYNTAX_PRIM_BINARY_DECL(greater_equal)
SYNTAX_PRIM_BINARY_DECL(less_equal)

SYNTAX_PRIM_BINARY_DECL(bitwise_or)
SYNTAX_PRIM_BINARY_DECL(bitwise_xor)
SYNTAX_PRIM_BINARY_DECL(bitwise_and)
SYNTAX_PRIM_BINARY_DECL(left_shift)
SYNTAX_PRIM_BINARY_DECL(right_shift)

// Broadcast one operand to the target shape.
// broadcast_axes: the i-th entry is the output axis that a's i-th axis maps to.
// Note: each dim of a must be 1 or equal to the output dim it maps to.
// e.g. if a[64] broadcasts to out[1, 64, 112, 112], then out_shape is {1, 64, 112, 112} and broadcast_axes is {1}.
Variable primitive_broadcast_to(const Variable& a,
const std::vector<int>& out_shape,
const std::vector<int>& broadcast_axes);
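A standalone sketch of the compatibility rule stated in the note above; BroadcastCompatible is illustrative and not part of this commit:

#include <vector>

// Returns true iff a_shape can broadcast to out_shape along broadcast_axes.
bool BroadcastCompatible(const std::vector<int>& a_shape,
                         const std::vector<int>& out_shape,
                         const std::vector<int>& broadcast_axes) {
  if (broadcast_axes.size() != a_shape.size()) return false;  // one target axis per input axis
  for (size_t i = 0; i < a_shape.size(); ++i) {
    const int axis = broadcast_axes[i];
    if (axis < 0 || axis >= static_cast<int>(out_shape.size())) return false;  // axis in range
    if (a_shape[i] != 1 && a_shape[i] != out_shape[axis]) return false;        // dim is 1 or matches
  }
  return true;
}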

/**
* Add two tensors element-wise.
*/
@@ -245,6 +322,16 @@ struct Program {
const Variable& variance,
const std::unordered_map<std::string, attr_t>& attr_store);

/**
 * Batch normalization for inference, composed of primitive ops.
 */
Variable fused_batchnorm_inference(const Variable& a,
const Variable& scale,
const Variable& bias,
const Variable& mean,
const Variable& variance,
const std::unordered_map<std::string, attr_t>& attr_store);

Variable scale(const Variable& a, const std::unordered_map<std::string, attr_t>& attr_store);

Variable softmax(const Variable& a, const std::unordered_map<std::string, attr_t>& attr_store);
2 changes: 1 addition & 1 deletion cinn/hlir/framework/graph.cc
@@ -27,8 +27,8 @@ Graph::Graph(const frontend::Program& prog, const Target& target) {
graph_node->as<NodeData>()->LinkTo(node_tmp);
}
}
-    int out_idx = 0;
     for (auto& output_v : temp->outputs) {
+      int out_idx = 0;
auto* output_data = new NodeData(node_ptr, out_idx++, 0, output_v->id);
node_tmp->LinkTo(output_data);
this->RegisterNode(output_v->id, output_data);
32 changes: 25 additions & 7 deletions cinn/hlir/framework/graph_compiler.cc
@@ -118,6 +118,19 @@ std::vector<ir::LoweredFunc> GraphCompiler::GetOpFunc(const Node* node) {
return func;
}

// Get the index of the most complex op in the fused group according to its OpPattern.
// If two ops have the same pattern, the later one is taken.
int GetMasterRefNode(const std::vector<Node*>& nodes) {
  auto& op_pattern_dict = Operator::GetAttrs<OpPatternKind>("OpPattern");
  int master_index = 0;
  int master_pattern = op_pattern_dict[nodes[0]->op()];
  for (int i = 1; i < nodes.size(); i++) {
    int pattern = op_pattern_dict[nodes[i]->op()];
    if (pattern >= master_pattern) {  // '>=' so ties go to the later op
      master_index = i;
      master_pattern = pattern;  // track the running max so earlier maxima are not overwritten
    }
  }
VLOG(3) << "master_index: " << master_index << ", master op: " << nodes[master_index]->op()->name;
return master_index;
}
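The selection rule in isolation, as an illustrative standalone function: an argmax that prefers the later index on ties, so patterns {2, 4, 4} yield index 2.

#include <vector>

// Illustrative only: mirrors GetMasterRefNode's tie-breaking on a plain vector.
int ArgmaxPreferLater(const std::vector<int>& patterns) {
  int best = 0;
  for (int i = 1; i < static_cast<int>(patterns.size()); ++i) {
    if (patterns[i] >= patterns[best]) best = i;  // '>=' makes later ties win
  }
  return best;
}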

std::vector<ir::LoweredFunc> GraphCompiler::GetOpFunc(const std::vector<Node*>& nodes) {
CHECK_GT(nodes.size(), 1) << "the number of nodes to fuse must be greater than 1";
auto& strategy = Operator::GetAttrs<StrategyFunction>("CINNStrategy");
@@ -133,7 +146,8 @@ std::vector<ir::LoweredFunc> GraphCompiler::GetOpFunc(const std::vector<Node*>&
std::unordered_set<NodeData*> in_vars;
std::unordered_set<NodeData*> out_vars;
std::unordered_map<NodeData*, Expr> temp_var_map;
-  ir::Tensor first_out_tensor;
+  ir::Tensor master_out_tensor;
+  int master_index = GetMasterRefNode(nodes);
for (auto& node : nodes) {
std::vector<ir::Tensor> temp_inputs;
std::vector<common::CINNValue> cinn_inputs;
@@ -181,12 +195,12 @@ std::vector<ir::LoweredFunc> GraphCompiler::GetOpFunc(const std::vector<Node*>&
OpStrategy::SelectImpl(strategy[node->op()](node->attrs, temp_inputs, out_types, output_shapes, target_));

common::CINNValuePack C = impl->fcompute(common::CINNValuePack{cinn_inputs});
-    if (index == 0) {
-      // use the first op's schedule as the fused ops' schedule as complex op like conv appear in the first.
+    if (index == master_index) {
+      // use the most complex op's schedule as the fused ops' schedule.
       C = impl->fschedule(C);
       CHECK(!C.empty());
-      Expr out = C[0];
-      first_out_tensor = out.as_tensor_ref();
+      Expr out          = C[0];
+      master_out_tensor = out.as_tensor_ref();
}
CHECK_GE(C.size(), 2);
CHECK_LE(C.size() - 1, node->outlinks_in_order().size());
@@ -237,8 +251,10 @@ std::vector<ir::LoweredFunc> GraphCompiler::GetOpFunc(const std::vector<Node*>&
inputs.insert(inputs.end(), outputs.begin(), outputs.end());

   ir::Tensor final_out_tensor = outputs.front();
-  stages[final_out_tensor]->CopyTransform(stages[first_out_tensor]);
-  stages[final_out_tensor]->CopyLoopInfo(stages[first_out_tensor]);
+  if (final_out_tensor->name != master_out_tensor->name) {
+    stages[final_out_tensor]->CopyTransform(stages[master_out_tensor]);
+    stages[final_out_tensor]->CopyLoopInfo(stages[master_out_tensor]);
+  }

for (auto& s : stages) {
auto& compute_ats = s.second->GetComputeAts();
@@ -255,7 +271,9 @@ std::vector<ir::LoweredFunc> GraphCompiler::GetOpFunc(const std::vector<Node*>&
new_relation.level = old_relation.level;

compute_ats.clear();
+        CHECK(new_relation.IsCompatible(s.second.get())) << "new computeAt should be compatible";
         compute_ats[new_stage->id()] = new_relation;
+        break;
}
}
}
1 change: 1 addition & 0 deletions cinn/hlir/op/CMakeLists.txt
@@ -5,6 +5,7 @@ core_gather_srcs(SRCS
broadcast.cc
transform.cc
elementwise.cc
op_util.cc
)

cc_test(test_cinn_op_broadcast SRCS op_broadcast_test.cc DEPS cinncore)
(The remaining 13 changed files are not shown here.)
