From f52ef4a47139b60bbd93d0f3e5cf97bac169c65e Mon Sep 17 00:00:00 2001
From: Yan Chunwei
Date: Mon, 7 Sep 2020 16:47:39 +0800
Subject: [PATCH] Add FC frontend program unittest (#206)

---
 cinn/frontend/CMakeLists.txt                |   5 +
 cinn/frontend/paddle/cpp/op_desc.cc         |   1 +
 cinn/frontend/paddle/pb/block_desc.h        |   1 +
 cinn/frontend/paddle/pb/op_desc.h           |   1 +
 cinn/frontend/syntax.cc                     | 103 ++++++++++++-----
 cinn/frontend/syntax.h                      |  96 ++++++++++-----
 cinn/frontend/syntax_test.cc                | 122 +++++++++++++++++++-
 cinn/hlir/framework/graph.cc                |   2 +-
 cinn/hlir/framework/graph.h                 |   2 +-
 cinn/hlir/framework/graph_compiler.cc       |   5 +-
 cinn/hlir/framework/graph_compiler.h        |   8 +-
 cinn/hlir/framework/infershape_pass_test.cc |   6 +-
 cinn/hlir/framework/node.cc                 |  13 ++-
 cinn/hlir/framework/op_strategy.h           |   1 +
 cinn/hlir/framework/pass.cc                 |   1 +
 cinn/hlir/framework/schedule.h              |   1 +
 cinn/hlir/framework/scope.cc                |  10 +-
 cinn/hlir/framework/scope.h                 |   8 +-
 cinn/hlir/framework/tensor.h                |   1 +
 cinn/hlir/op/broadcast.cc                   |   5 +-
 cinn/hlir/op/nn.cc                          |   3 +
 cinn/hlir/op/transform.cc                   |  26 ++++-
 cinn/hlir/pass/infershape.cc                |  12 +-
 cinn/hlir/pe/broadcast.cc                   |   9 +-
 cinn/hlir/pe/transform.cc                   |  15 ++-
 cinn/ir/tensor.cc                           |   3 +-
 cinn/pybind/frontend.cc                     |   3 +-
 cinn/pybind/runtime.cc                      |  12 +-
 cinn/runtime/cpu/host_intrinsics.cc         |   5 +-
 cinn/utils/string.cc                        |  15 ---
 cinn/utils/string.h                         |   9 +-
 python/tests/test_op_broadcast.py           |   1 +
 32 files changed, 379 insertions(+), 126 deletions(-)

diff --git a/cinn/frontend/CMakeLists.txt b/cinn/frontend/CMakeLists.txt
index c38c0d1c5c..ba023964d9 100644
--- a/cinn/frontend/CMakeLists.txt
+++ b/cinn/frontend/CMakeLists.txt
@@ -2,7 +2,12 @@ set(srcs
   syntax.cc
   )
 
+if(NOT WITH_CUDA)
 cc_test(test_frontend_syntax SRCS syntax_test.cc DEPS core)
+else()
+nv_test(test_frontend_syntax SRCS syntax_test.cc DEPS core)
+endif()
+
 
 add_subdirectory(paddle)
 
diff --git a/cinn/frontend/paddle/cpp/op_desc.cc b/cinn/frontend/paddle/cpp/op_desc.cc
index 69ec6fa820..eed71d4a64 100644
--- a/cinn/frontend/paddle/cpp/op_desc.cc
+++ b/cinn/frontend/paddle/cpp/op_desc.cc
@@ -1,4 +1,5 @@
 #include "cinn/frontend/paddle/cpp/op_desc.h"
+
 #include
 
 namespace cinn::frontend::paddle::cpp {
diff --git a/cinn/frontend/paddle/pb/block_desc.h b/cinn/frontend/paddle/pb/block_desc.h
index 5b05c62b5f..69b48ad845 100644
--- a/cinn/frontend/paddle/pb/block_desc.h
+++ b/cinn/frontend/paddle/pb/block_desc.h
@@ -1,5 +1,6 @@
 #pragma once
 #include
+
 #include "cinn/frontend/paddle/cpp/desc_api.h"
 #include "cinn/frontend/paddle/framework.pb.h"
 
diff --git a/cinn/frontend/paddle/pb/op_desc.h b/cinn/frontend/paddle/pb/op_desc.h
index 7ee6619b71..3c9bee06a0 100644
--- a/cinn/frontend/paddle/pb/op_desc.h
+++ b/cinn/frontend/paddle/pb/op_desc.h
@@ -1,5 +1,6 @@
 #pragma once
 #include
+
 #include "cinn/frontend/paddle/cpp/op_desc.h"
 #include "cinn/frontend/paddle/framework.pb.h"
 
diff --git a/cinn/frontend/syntax.cc b/cinn/frontend/syntax.cc
index 4fa2b7095e..a14f667004 100644
--- a/cinn/frontend/syntax.cc
+++ b/cinn/frontend/syntax.cc
@@ -1,4 +1,6 @@
 #include "cinn/frontend/syntax.h"
+
+#include "cinn/frontend/paddle/model_parser.h"
 #include "cinn/hlir/framework/node.h"
 #include "cinn/hlir/framework/op.h"
 #include "cinn/utils/string.h"
@@ -14,33 +16,15 @@ void Instruction::PrepareOutputs() {
   }
 }
 
-Instruction::Instruction(std::string_view op_type, Program* parent)
+Instruction::Instruction(std::string_view op_type, const std::vector<Variable>& inputs, Program* parent)
     : common::Shared<_Instruction_>(common::make_shared<_Instruction_>()) {
   get()->op_type = op_type;
   get()->parent_program = parent;
+  get()->inputs = inputs;
   PrepareOutputs();
 }
 
-Placeholder::operator Variable() {
-  Variable var(id());
-  var->shape = shape();
-  var->type = type_;
-  return var;
-}
-
-Variable Program::add(const Variable& a, const Variable& b) {
-  Instruction instr("elementwise_add");
-  instr.SetInputs({a, b});
-  AddInstruction(instr);
-  return instr.GetOutputs()[0];
-}
-
-Variable Program::relu(const Variable& a) {
-  Instruction instr("relu");
-  instr.SetInputs({a});
-  AddInstruction(instr);
-  return instr.GetOutputs()[0];
-}
+Placeholder::operator Variable() const { return var_; }
 
 std::vector<Variable> Program::conv2d(
     const Variable& a,
@@ -51,7 +35,7 @@ std::vector<Variable> Program::conv2d(
   for (auto& iter : attr_store) {
     instr.SetAttr(iter.first, iter.second);
   }
-  AddInstruction(instr);
+  AppendInstruction(instr);
   return instr.GetOutputs();
 }
 
@@ -63,18 +47,18 @@ Variable Program::batchnorm(const Variable& a,
   for (auto& iter : attr_store) {
     instr.SetAttr(iter.first, iter.second);
   }
-  AddInstruction(instr);
+  AppendInstruction(instr);
   return instr.GetOutputs()[0];
 }
 
 Instruction& Program::operator[](size_t i) {
-  CHECK_LT(i, instrs.size());
-  return instrs[i];
+  CHECK_LT(i, instrs_.size());
+  return instrs_[i];
 }
 
 const Instruction& Program::operator[](size_t i) const {
-  CHECK_LT(i, instrs.size());
-  return instrs[i];
+  CHECK_LT(i, instrs_.size());
+  return instrs_[i];
 }
 
 std::ostream& operator<<(std::ostream& os, const Variable& x) {
@@ -95,7 +79,68 @@ std::ostream& operator<<(std::ostream& os, const Instruction& instr) {
   return os;
 }
 
+// Add an Instruction to a program given a Paddle-format \p op_desc.
+void ProgramAddOp(Program* program, const paddle::cpp::OpDesc& op_desc) {}
+
+void LoadPaddleProgram(const std::string& model_dir, bool is_combined) {
+  hlir::framework::Scope scope;
+  paddle::cpp::ProgramDesc program_desc;
+  paddle::LoadModelPb(model_dir, "__model__", "", &scope, &program_desc, is_combined);
+  CHECK_EQ(program_desc.BlocksSize(), 1) << "CINN can only support the model with a single block";
+  auto* block_desc = program_desc.GetBlock(0);
+  for (int i = 0; i < block_desc->OpsSize(); i++) {
+    auto* op_desc = block_desc->GetOp(i);
+  }
+}
+
+void Program::SetInputs(const std::vector<Variable>& xs) {
+  CHECK(!xs.empty()) << "At least one input is needed for a program!";
+  for (int i = 0; i < xs.size(); i++) {
+    CHECK(!xs[i]->shape.empty()) << "Found " << i << "-th input's shape is not set yet";
+    CHECK(!xs[i]->type.is_unk()) << "Found " << i << "-th input's type is not set yet";
+    inputs_.push_back(xs[i]);
+  }
+}
+
+void Program::Validate() const {
+  CHECK(!inputs_.empty()) << "Inputs of the program is not set yet";
+  CHECK(!instrs_.empty()) << "No instruction is added yet";
+}
+
+Variable Program::add(const Variable& a, const Variable& b) {
+  Instruction instr("elementwise_add", {a, b});
+  AppendInstruction(instr);
+  return instr.GetOutput(0);
+}
+
+Variable Program::elementwise_add(const Variable& a, const Variable& b, int axis) {
+  Instruction instr("elementwise_add", {a, b});
+  instr.SetAttr("axis", axis);
+  AppendInstruction(instr);
+  return instr.GetOutput(0);
+}
+
+Variable Program::relu(const Variable& a) {
+  Instruction instr("relu", {a});
+  AppendInstruction(instr);
+  return instr.GetOutput(0);
+}
+
+Variable Program::relu6(const Variable& a) {
+  Instruction instr("relu6", {a});
+  AppendInstruction(instr);
+  return instr.GetOutput(0);
+}
+Variable Program::mul(
+    const Variable& a, const Variable& b, bool trans_a, bool trans_b, int x_num_col_dims, int y_num_col_dims) {
+  Instruction instr("mul", {a,
b}); + instr.SetAttr("trans_a", trans_a); + instr.SetAttr("trans_b", trans_b); + instr.SetAttr("x_num_col_dims", x_num_col_dims); + instr.SetAttr("y_num_col_dims", y_num_col_dims); + AppendInstruction(instr); + return instr.GetOutput(0); +} + } // namespace frontend } // namespace cinn - -// CINN_REGISTRY_ENABLE(cinn::hlir::framework::Operator); diff --git a/cinn/frontend/syntax.h b/cinn/frontend/syntax.h index ddecefbce6..a6c1964a8a 100644 --- a/cinn/frontend/syntax.h +++ b/cinn/frontend/syntax.h @@ -9,6 +9,7 @@ #include #include +#include "cinn/common/common.h" #include "cinn/common/context.h" #include "cinn/common/object.h" #include "cinn/common/type.h" @@ -20,34 +21,6 @@ namespace frontend { struct Program; struct Variable; -/** - * Placeholder is the fed slot of a computation. - */ -class Placeholder { - public: - /** - * @param type Type of the fed - * @param shape Shape of the fed - * @param id ID of the fed - */ - Placeholder(const common::Type& type, const std::vector& shape, std::string_view id = "") - : type_(type), shape_(shape), id_(id.empty() ? common::Context::Global().NewName("placeholder") : id) {} - - const std::vector& shape() const { return shape_; } - - std::string_view id() const { return id_; } - - operator Variable(); - - Program* parent_program() { return parent_program_; } - - private: - common::Type type_; - std::string id_{}; - std::vector shape_; - Program* parent_program_{}; -}; - struct _Variable_ : public common::Object { std::string id; common::Type type; @@ -73,6 +46,38 @@ struct Variable : public common::Shared<_Variable_> { const _Variable_* operator->() const { return get(); } }; +/** + * Placeholder is the fed slot of a computation. + */ +class Placeholder { + public: + /** + * @param type Type of the fed + * @param shape Shape of the fed + * @param id ID of the fed + */ + Placeholder(const common::Type& type, const std::vector& shape, std::string_view id = "") + : id_(id.empty() ? common::Context::Global().NewName("placeholder") : id), var_{id} { + var_->shape = shape; + var_->type = type; + } + + const std::vector& shape() const { return var_->shape; } + + Type type() const { return var_->type; } + + std::string_view id() const { return id_; } + + operator Variable() const; + + Program* parent_program() { return parent_program_; } + + private: + Variable var_; + std::string id_{}; + Program* parent_program_{}; +}; + /** * Data of a Instruction. */ @@ -94,7 +99,7 @@ struct _Instruction_ : public common::Object { * Instruction is the basic computational unit of a Program, similar to the operator concept in a DNN platform. */ struct Instruction : public common::Shared<_Instruction_> { - explicit Instruction(std::string_view op_type, Program* parent = nullptr); + explicit Instruction(std::string_view op_type, const std::vector& inputs = {}, Program* parent = nullptr); /** * Set the inputs of the instruction. @@ -102,6 +107,10 @@ struct Instruction : public common::Shared<_Instruction_> { */ void SetInputs(const std::vector& vars) { get()->inputs = vars; } const std::vector& GetOutputs() const { return get()->outputs; } + const Variable& GetOutput(size_t offset) const { + CHECK_LT(offset, get()->outputs.size()); + return GetOutputs()[offset]; + } /** * Set an attribute of the instruction. @@ -136,6 +145,7 @@ struct Instruction : public common::Shared<_Instruction_> { * Program is a representation of a computation. */ struct Program { + void SetInputs(const std::vector& xs); /** * Add two variables. 
* @@ -145,6 +155,21 @@ struct Program { */ Variable add(const Variable& a, const Variable& b); + /** + * Multiply two matrix. + */ + Variable mul(const Variable& a, + const Variable& b, + bool trans_a = false, + bool trans_b = false, + int x_num_col_dims = -1, + int y_num_col_dims = -1); + + /** + * Add two tensors element-wise. + */ + Variable elementwise_add(const Variable& a, const Variable& b, int axis = 0); + /** * Apply Rectified Linear Unit on input Variable. * Actually apply: outupt = max(input,0) @@ -153,6 +178,7 @@ struct Program { * @return The result. */ Variable relu(const Variable& a); + Variable relu6(const Variable& a); /** * The convolution2D layer calculates the output based on the input, filter @@ -193,14 +219,20 @@ struct Program { * Get number of instructions in the program. * @return */ - inline size_t size() const { return instrs.size(); } + inline size_t size() const { return instrs_.size(); } + + void Validate() const; private: - void AddInstruction(const Instruction& other) { instrs.push_back(other); } + void AppendInstruction(const Instruction& other) { instrs_.push_back(other); } + + std::vector instrs_; - std::vector instrs; + std::vector inputs_; }; +void LoadPaddleProgram(const std::string& model_dir, bool is_combined); + std::ostream& operator<<(std::ostream& os, const Variable& x); std::ostream& operator<<(std::ostream& os, const Instruction& instr); diff --git a/cinn/frontend/syntax_test.cc b/cinn/frontend/syntax_test.cc index 17b6f4ec6e..fdba70930c 100644 --- a/cinn/frontend/syntax_test.cc +++ b/cinn/frontend/syntax_test.cc @@ -2,28 +2,138 @@ #include +#include + #include "cinn/cinn.h" +#include "cinn/hlir/framework/graph.h" +#include "cinn/hlir/framework/graph_compiler.h" +#include "cinn/hlir/framework/pass.h" #include "cinn/hlir/op/use_ops.h" +#include "cinn/hlir/pass/use_pass.h" namespace cinn { namespace frontend { -TEST(syntax, basic) { +std::unique_ptr CreateAddProgram() { const int M = 32; const int N = 24; Placeholder a(Float(32), {M, N}); Placeholder b(Float(32), {M, N}); - Program program; + std::unique_ptr program(new Program); + + auto c = program->add(a, b); + auto d = program->add(a, c); + + program->SetInputs({a, b}); + program->Validate(); - auto c = program.add(a, b); - auto d = program.add(a, c); + return program; +} +TEST(syntax, basic) { + auto program = CreateAddProgram(); // output program - for (int i = 0; i < program.size(); i++) { - LOG(INFO) << "instruction: " << program[i]; + for (int i = 0; i < program->size(); i++) { + LOG(INFO) << "instruction: " << (*program)[i]; + } +} + +void SetRandData(hlir::framework::Tensor* tensor, Target target) { + auto* data = tensor->mutable_data(target); + for (size_t j = 0; j < tensor->shape().numel(); j++) { + unsigned int seed = j; + data[j] = (rand_r(&seed) * 1.f) / RAND_MAX; // All random data } } +TEST(syntax, program_execute_multi_elementwise_add) { + auto program = CreateAddProgram(); + auto graph = std::make_shared(*program); + LOG(INFO) << "graph:\n" << graph->Visualize(); + + hlir::framework::ApplyPass(graph.get(), "InferShape"); + Target target = common::DefaultHostTarget(); + auto scope = BuildScope(target, graph); + + hlir::framework::GraphCompiler gc(target, scope, graph); + auto runtime_program = gc.Build(); + + scope->Var("a"); + scope->Var("b"); + + auto A = scope->GetTensor("a"); + auto B = scope->GetTensor("b"); + SetRandData(A, target); + SetRandData(B, target); + + runtime_program->Execute(); +} + +TEST(syntax, program_execute_multi_elementwise_add2) { + auto program = 
CreateAddProgram(); + auto graph = std::make_shared(*program); + LOG(INFO) << "graph:\n" << graph->Visualize(); + + hlir::framework::ApplyPass(graph.get(), "InferShape"); + Target target = common::DefaultHostTarget(); + auto scope = BuildScope(target, graph); + + hlir::framework::GraphCompiler gc(target, scope, graph); + auto runtime_program = gc.Build(); + + scope->Var("a"); + scope->Var("b"); + + auto A = scope->GetTensor("a"); + auto B = scope->GetTensor("b"); + SetRandData(A, target); + SetRandData(B, target); + + runtime_program->Execute(); +} + +TEST(syntax, program_execute_fc) { + const int B = 10; // batch size + const int M = 32; + const int K = 18; + const int N = 24; + + Placeholder a(Float(32), {B, M, K}, "a"); + Placeholder w(Float(32), {K, N}, "w"); // weight + Placeholder b(Float(32), {N}, "b"); // bias + + Program program; + auto mul_out = program.mul(a, w, false /*trans_a*/, false /*trans_b*/, 2, 1); + auto add_out = program.add(mul_out, b); + program.SetInputs({a, w, b}); + program.Validate(); + + auto graph = std::make_shared(program); + + hlir::framework::ApplyPass(graph.get(), "InferShape"); + Target target = common::DefaultHostTarget(); + auto scope = BuildScope(target, graph); + + hlir::framework::GraphCompiler gc(target, scope, graph); + auto runtime_program = gc.Build(); + + scope->Var(std::string(a.id())); + scope->Var(std::string(w.id())); + scope->Var(std::string(b.id())); + scope->Var(std::string(mul_out->id)); + + auto at = scope->GetTensor(std::string(a.id())); + auto wt = scope->GetTensor(std::string(w.id())); + auto bt = scope->GetTensor(std::string(b.id())); + auto fake_outt = scope->GetTensor(std::string(mul_out->id)); + SetRandData(at, target); + SetRandData(wt, target); + SetRandData(bt, target); + SetRandData(fake_outt, target); + + runtime_program->Execute(); +} + } // namespace frontend } // namespace cinn diff --git a/cinn/hlir/framework/graph.cc b/cinn/hlir/framework/graph.cc index 0d3ce60e5e..aa1d4e8fd4 100644 --- a/cinn/hlir/framework/graph.cc +++ b/cinn/hlir/framework/graph.cc @@ -4,7 +4,7 @@ namespace cinn { namespace hlir { namespace framework { -Graph::Graph(frontend::Program prog) { +Graph::Graph(const frontend::Program& prog) { std::unordered_map> shape_dict; std::unordered_map dtype_dict; int counter = 0; diff --git a/cinn/hlir/framework/graph.h b/cinn/hlir/framework/graph.h index f06de1456d..bec135a20c 100644 --- a/cinn/hlir/framework/graph.h +++ b/cinn/hlir/framework/graph.h @@ -19,7 +19,7 @@ namespace framework { */ class Graph : public cinn::common::Graph { public: - explicit Graph(frontend::Program prog); + explicit Graph(const frontend::Program& prog); /** \brief outputs of the computation graph. 
*/ std::vector outputs; diff --git a/cinn/hlir/framework/graph_compiler.cc b/cinn/hlir/framework/graph_compiler.cc index eb23093c07..c28d54ffad 100644 --- a/cinn/hlir/framework/graph_compiler.cc +++ b/cinn/hlir/framework/graph_compiler.cc @@ -85,12 +85,12 @@ ir::LoweredFunc GraphCompiler::GetOpFunc(const Node* node) { C = impl->fschedule(C); for (int i = 0; i < C.get()->size() - 1; i++) { ir::Expr temp = C[i]; - stages->Insert(temp.as_tensor_ref(), ir::CreateStage(temp.as_tensor_ref()).get()); + stages->InsertLazily(temp.as_tensor_ref()); inputs.push_back(temp.as_tensor_ref()); } auto func = Lower(GenOpFuncName(node), stages, inputs); - LOG(INFO) << "The function of node [" << node->attrs.node_name << "] is: " << func; + LOG(INFO) << "The function of node [" << node->attrs.node_name << "] is:\n" << func; return func; } @@ -121,6 +121,7 @@ std::shared_ptr BuildScope(Target target, const std::shared_ptr& g for (auto& shape_dim : iter.second) { shape.push_back(Shape::dim_t(shape_dim)); } + VLOG(3) << "Tensor [" << iter.first << "] resize to " << utils::Join(shape, ","); tensor.Resize(Shape{shape}); CHECK_EQ(dtype_dict.at(iter.first), Float(32)) << "The dtype of node " << iter.first << " is not float! Other dtype is not implemented yet."; diff --git a/cinn/hlir/framework/graph_compiler.h b/cinn/hlir/framework/graph_compiler.h index 7be35e17f3..31f2b70832 100644 --- a/cinn/hlir/framework/graph_compiler.h +++ b/cinn/hlir/framework/graph_compiler.h @@ -35,7 +35,11 @@ class Program { * Execute the program -- that is running all the instructions inside it. */ void Execute() { - for (auto& ins : instrs_) ins->Run(); + for (auto& ins : instrs_) { + for (int i = 0; i < 100; i++) { + ins->Run(); + } + } } /** @@ -43,6 +47,8 @@ class Program { */ size_t size() const { return instrs_.size(); } + ~Program() {} + private: // We need to hold scope to assure tensors alive used in instructions. 
std::shared_ptr scope_; diff --git a/cinn/hlir/framework/infershape_pass_test.cc b/cinn/hlir/framework/infershape_pass_test.cc index 3858a253c8..ba448795d5 100644 --- a/cinn/hlir/framework/infershape_pass_test.cc +++ b/cinn/hlir/framework/infershape_pass_test.cc @@ -46,10 +46,12 @@ TEST(Operator, GetAttrs) { auto d = prog.add(c, b); auto e = prog.add(c, d); ASSERT_EQ(prog.size(), 3UL); - std::shared_ptr g(new Graph(prog)); + auto g = std::make_shared(prog); ApplyPass(g.get(), "InferShape"); + Target target(Target::OS::Linux, Target::Arch::X86, Target::Bit::k64, {}); - std::shared_ptr scope = BuildScope(target, g); + auto scope = BuildScope(target, g); + GraphCompiler gc(target, scope, g); std::unique_ptr program = gc.Build(); diff --git a/cinn/hlir/framework/node.cc b/cinn/hlir/framework/node.cc index 5afe8b2f4d..040d70fe56 100644 --- a/cinn/hlir/framework/node.cc +++ b/cinn/hlir/framework/node.cc @@ -1,4 +1,5 @@ #include "cinn/hlir/framework/node.h" + #include namespace cinn { @@ -16,15 +17,15 @@ std::tuple NodeData::LinkTo(Node* other) namespace { struct PyBindNodeAttrVisitor { - std::stringstream &out; - PyBindNodeAttrVisitor(std::stringstream &out) : out(out) {} + std::stringstream& out; + explicit PyBindNodeAttrVisitor(std::stringstream& out) : out(out) {} void operator()(int v) { out << "int: " << v; } void operator()(float v) { out << "float: " << v; } void operator()(bool v) { out << "bool: " << v; } - void operator()(const std::string &v) { out << "string: " << v; } + void operator()(const std::string& v) { out << "string: " << v; } #define VISIT_ELEMENTS(T__) \ - void operator()(const std::vector &vs) { \ + void operator()(const std::vector& vs) { \ if (vs.empty()) return; \ for (int i = 0; i < vs.size() - 1; i++) out << vs[i] << ","; \ out << vs.back(); \ @@ -37,10 +38,10 @@ struct PyBindNodeAttrVisitor { } // namespace -std::ostream &operator<<(std::ostream &os, const NodeAttr &node_attr) { +std::ostream& operator<<(std::ostream& os, const NodeAttr& node_attr) { std::stringstream ss; ss << "NodeAttr:\n"; - for (auto &item : node_attr.attr_store) { + for (auto& item : node_attr.attr_store) { std::stringstream os; PyBindNodeAttrVisitor visitor(os); std::visit(visitor, item.second); diff --git a/cinn/hlir/framework/op_strategy.h b/cinn/hlir/framework/op_strategy.h index 45aaa51079..9d0c838cef 100644 --- a/cinn/hlir/framework/op_strategy.h +++ b/cinn/hlir/framework/op_strategy.h @@ -3,6 +3,7 @@ #include #include #include + #include "cinn/hlir/framework/node.h" #include "cinn/hlir/framework/schedule.h" #include "cinn/lang/packed_func.h" diff --git a/cinn/hlir/framework/pass.cc b/cinn/hlir/framework/pass.cc index 45c6b4a9d3..0f4fd10784 100644 --- a/cinn/hlir/framework/pass.cc +++ b/cinn/hlir/framework/pass.cc @@ -1,4 +1,5 @@ #include "cinn/hlir/framework/pass.h" + #include "cinn/hlir/pass/use_pass.h" namespace cinn { diff --git a/cinn/hlir/framework/schedule.h b/cinn/hlir/framework/schedule.h index 24f5616638..85feb4358c 100644 --- a/cinn/hlir/framework/schedule.h +++ b/cinn/hlir/framework/schedule.h @@ -2,6 +2,7 @@ #include #include #include + #include "cinn/cinn.h" #include "cinn/ir/tensor.h" diff --git a/cinn/hlir/framework/scope.cc b/cinn/hlir/framework/scope.cc index 958d839728..1de611154f 100644 --- a/cinn/hlir/framework/scope.cc +++ b/cinn/hlir/framework/scope.cc @@ -5,11 +5,17 @@ namespace hlir { namespace framework { Variable* Scope::FindVar(const std::string& name) const { - auto it = dic.find(name); - if (it != dic.end()) return it->second.get(); + auto it = 
data_.find(name); + if (it != data_.end()) return it->second.get(); return nullptr; } +Tensor* Scope::GetTensor(const std::string& name) const { + auto* var = FindVar(name); + CHECK(var) << "No variable called [" << name << "] found"; + return &std::get(*var); +} + } // namespace framework } // namespace hlir } // namespace cinn diff --git a/cinn/hlir/framework/scope.h b/cinn/hlir/framework/scope.h index 06b26b2580..322c95f597 100644 --- a/cinn/hlir/framework/scope.h +++ b/cinn/hlir/framework/scope.h @@ -14,6 +14,8 @@ namespace framework { using Variable = std::variant; +struct Tensor; + class Scope { public: Scope() = default; @@ -25,8 +27,10 @@ class Scope { //! Find a variable, get null if not exists. Variable* FindVar(const std::string& name) const; + Tensor* GetTensor(const std::string& name) const; + private: - std::unordered_map> dic; + std::unordered_map> data_; CINN_DISALLOW_COPY_AND_ASSIGN(Scope); }; @@ -36,7 +40,7 @@ Variable* Scope::Var(const std::string& name) { Variable* x = FindVar(name); if (x) return x; auto* data = new Variable(T()); - dic[name].reset(data); + data_[name].reset(data); return data; } diff --git a/cinn/hlir/framework/tensor.h b/cinn/hlir/framework/tensor.h index 5674e7e7af..1d19072cef 100644 --- a/cinn/hlir/framework/tensor.h +++ b/cinn/hlir/framework/tensor.h @@ -8,6 +8,7 @@ #include "cinn/common/common.h" #include "cinn/common/macros.h" #include "cinn/hlir/framework/buffer.h" +#include "cinn/utils/string.h" namespace cinn { namespace hlir { diff --git a/cinn/hlir/op/broadcast.cc b/cinn/hlir/op/broadcast.cc index ca24bb8436..87a6cde6ba 100644 --- a/cinn/hlir/op/broadcast.cc +++ b/cinn/hlir/op/broadcast.cc @@ -1,6 +1,7 @@ #include "cinn/hlir/pe/broadcast.h" #include + #include "cinn/hlir/framework/node.h" #include "cinn/hlir/framework/op.h" #include "cinn/hlir/framework/op_strategy.h" @@ -92,7 +93,7 @@ std::shared_ptr StrategyForElementwiseMul(const framework::NodeAttr std::vector> InferShapeForElementwise(const std::vector> &inputs_shape, const framework::NodeAttr &attrs) { - CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) << "The input's shape size is 0! 
Please check again."; + CHECK_EQ(inputs_shape.size(), 2UL); std::vector> res{inputs_shape[0]}; return res; } @@ -197,4 +198,6 @@ CINN_REGISTER_HELPER(broadcast_ops) { .set_attr("infershape", std::function(cinn::hlir::op::InferShapeForScale)) .set_attr("inferdtype", std::function(cinn::hlir::op::InferDtypeForScale)) .set_support_level(4); + + return true; } diff --git a/cinn/hlir/op/nn.cc b/cinn/hlir/op/nn.cc index 36294bf039..7497527c68 100644 --- a/cinn/hlir/op/nn.cc +++ b/cinn/hlir/op/nn.cc @@ -1,4 +1,5 @@ #include "cinn/hlir/pe/nn.h" + #include "cinn/hlir/framework/node.h" #include "cinn/hlir/framework/op.h" #include "cinn/hlir/framework/op_strategy.h" @@ -276,4 +277,6 @@ CINN_REGISTER_HELPER(nn_ops) { .set_attr("infershape", std::function(cinn::hlir::op::InferShapeForBatchNorm)) .set_attr("inferdtype", std::function(cinn::hlir::op::InferDtypeForBatchNorm)) .set_support_level(4); + + return true; } diff --git a/cinn/hlir/op/transform.cc b/cinn/hlir/op/transform.cc index 453b11cbfc..3dd6e0b0ae 100644 --- a/cinn/hlir/op/transform.cc +++ b/cinn/hlir/op/transform.cc @@ -1,7 +1,9 @@ #include "cinn/hlir/pe/transform.h" + #include "cinn/hlir/framework/node.h" #include "cinn/hlir/framework/op.h" #include "cinn/hlir/framework/op_strategy.h" +#include "cinn/ir/ir_printer.h" namespace cinn { namespace hlir { @@ -41,6 +43,7 @@ std::shared_ptr StrategyForMul(const framework::NodeAttr &attrs, auto out = pe::Matmul( A.as_tensor_ref(), B.as_tensor_ref(), trans_a, trans_b, x_num_col_dims, y_num_col_dims, UniqName("C")); + VLOG(3) << "matmul out: " << out; auto stages = CreateStages({out}); *ret = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}}; @@ -61,9 +64,25 @@ std::shared_ptr StrategyForMul(const framework::NodeAttr &attrs, std::vector> InferShapeForMul(const std::vector> &inputs_shape, const framework::NodeAttr &attrs) { - CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) << "The input's shape size is 0! 
Please check again."; - std::vector> res{inputs_shape[0]}; - return res; + VLOG(3) << "Mul shape0: " << utils::Join(inputs_shape[0], ","); + VLOG(3) << "Mul shape1: " << utils::Join(inputs_shape[1], ","); + CHECK(!inputs_shape.empty() && !inputs_shape[0].empty()) << "The input's shape is empty"; + + int x_num_col_dims = -1; + if (attrs.attr_store.count("x_num_col_dims")) { + x_num_col_dims = std::get(attrs.attr_store.at("x_num_col_dims")); + } + int y_num_col_dims = -1; + if (attrs.attr_store.count("y_num_col_dims")) { + y_num_col_dims = std::get(attrs.attr_store.at("y_num_col_dims")); + } + + std::vector out_shape; + for (int i = 0; i < x_num_col_dims; i++) out_shape.push_back(inputs_shape[0][i]); + for (int i = 0; i < y_num_col_dims; i++) out_shape.push_back(inputs_shape[1][inputs_shape.size() - 1 - i]); + + if (out_shape.empty()) return {{1}}; + return {out_shape}; } std::vector InferDtypeForMul(const std::vector &inputs_type, const framework::NodeAttr &attrs) { @@ -85,4 +104,5 @@ CINN_REGISTER_HELPER(transform_ops) { .set_attr("infershape", std::function(cinn::hlir::op::InferShapeForMul)) .set_attr("inferdtype", std::function(cinn::hlir::op::InferDtypeForMul)) .set_support_level(4); + return true; } diff --git a/cinn/hlir/pass/infershape.cc b/cinn/hlir/pass/infershape.cc index 6bab9277bf..6ca5c25aa3 100644 --- a/cinn/hlir/pass/infershape.cc +++ b/cinn/hlir/pass/infershape.cc @@ -3,6 +3,7 @@ #include "cinn/hlir/framework/op.h" #include "cinn/hlir/framework/pass.h" #include "cinn/hlir/pass/use_pass.h" +#include "cinn/utils/string.h" namespace cinn { namespace hlir { @@ -24,11 +25,13 @@ void InferShapePass(Graph* graph) { auto& op_inferdtype = Operator::GetAttrs(const std::vector&, const framework::NodeAttr&)>>( "inferdtype"); - for (auto& node : store_nodes) { - if (node->is_type()) { + + for (auto& n : store_nodes) { + auto node = n->safe_as(); + if (node) { std::vector> inputs_shape; std::vector inputs_dtype; - for (auto& in_edge : node->inlinks()) { + for (auto& in_edge : node->inlinks_in_order()) { auto* source_node = in_edge->source()->safe_as(); CHECK(source_node); CHECK(shape_dict.count(source_node->id())) << "No shape for " << source_node->id(); @@ -53,6 +56,7 @@ void InferShapePass(Graph* graph) { auto* sink_node = out_edge->sink()->safe_as(); CHECK(sink_node); + VLOG(3) << "Infershape: " << sink_node->id() << " " << utils::Join(out_shape[counter], ","); shape_dict[sink_node->id()] = out_shape[counter]; dtype_dict[sink_node->id()] = out_dtype[counter]; counter++; @@ -75,4 +79,6 @@ CINN_REGISTER_HELPER(passes) { .provide_graph_attr("infershape") .provide_graph_attr("inferdtype") .set_body(cinn::hlir::pass::InferShapePass); + + return true; } diff --git a/cinn/hlir/pe/broadcast.cc b/cinn/hlir/pe/broadcast.cc index fc5b3ea60c..dcef8db4f7 100644 --- a/cinn/hlir/pe/broadcast.cc +++ b/cinn/hlir/pe/broadcast.cc @@ -42,9 +42,10 @@ void GetBroadcastShape(const std::vector& shape1, int size2 = shape2_new.size(); Expr one(1); int i; + // insert common axis from right to left to common_axis for (i = 1; i <= std::min(size1, size2); ++i) { - const _Var_* var1 = shape1[size1 - i].As<_Var_>(); - const _Var_* var2 = shape2_new[size2 - i].As<_Var_>(); + auto* var1 = shape1[size1 - i].as_var(); + auto* var2 = shape2_new[size2 - i].as_var(); if (MathEqual(shape1[size1 - i], shape2_new[size2 - i])) { common_shape->insert(common_shape->begin(), shape1[size1 - i]); broadcast_flag1->emplace_back(true); @@ -73,8 +74,8 @@ void GetBroadcastShape(const std::vector& shape1, 
broadcast_flag1->emplace_back(true); broadcast_flag2->emplace_back(true); } else { - CHECK(false) << "Incompatible broadcast dims " << shape1[size1 - i] << " and " << shape2_new[size2 - i] - << " in: " << shape1 << " and " << shape2_new << std::endl; + LOG(FATAL) << "Incompatible broadcast dims " << shape1[size1 - i] << " and " << shape2_new[size2 - i] + << " in: " << shape1 << " and " << shape2_new << std::endl; } } if (size1 != size2) { diff --git a/cinn/hlir/pe/transform.cc b/cinn/hlir/pe/transform.cc index 3607d6deda..df395791de 100644 --- a/cinn/hlir/pe/transform.cc +++ b/cinn/hlir/pe/transform.cc @@ -1,10 +1,11 @@ #include "cinn/hlir/pe/transform.h" +#include + #include "cinn/common/ir_util.h" #include "cinn/ir/tensor.h" #include "cinn/lang/compute.h" - -#include +#include "cinn/optim/ir_simplify.h" namespace cinn { namespace hlir { @@ -70,11 +71,12 @@ void GetMatmulIndice(const std::vector& shape1_new, // B reduce axes for (size_t i = 0; i < y_num_col_dims; i++) { reduce_shape2 = reduce_shape2 * shape2_new[i]; - // tempory check - CHECK(MathEqual(shape1_new[shape1_new.size() - 1 - i], shape2_new[i])); + optim::Simplify(&reduce_shape2); indice2->emplace_back((*indice1)[indice1->size() - 1 - i]); } - CHECK(MathEqual(reduce_shape1, reduce_shape2)); + + CHECK(MathEqual(reduce_shape1, reduce_shape2)) + << "reduce shape not match: " << reduce_shape1 << " vs " << reduce_shape2; CHECK_GE(indices.size(), shape2_new.size() - y_num_col_dims); for (size_t i = y_num_col_dims; i < shape2_new.size(); i++) { indice2->emplace_back(indices[x_num_col_dims + i - y_num_col_dims]); @@ -103,6 +105,7 @@ Tensor Matmul(const Tensor& A, std::vector reduce_axes; GetMatmulOutputShape( A->shape, B->shape, &shape1_new, &shape2_new, &output_shape, trans_a, trans_b, x_num_col_dims, y_num_col_dims); + auto fn = [&](const std::vector& indices) { GetMatmulIndice(shape1_new, shape2_new, @@ -116,7 +119,7 @@ Tensor Matmul(const Tensor& A, &reduce_axes); return ReduceSum(A(A_indice) * B(B_indice), Expr()); }; - return Compute(output_shape, fn, name, reduce_axes); + return Compute(output_shape, fn, name, reduce_axes, output_shape); } } // namespace pe diff --git a/cinn/ir/tensor.cc b/cinn/ir/tensor.cc index bac7db802f..52279b0a7d 100644 --- a/cinn/ir/tensor.cc +++ b/cinn/ir/tensor.cc @@ -449,7 +449,8 @@ ir::Tensor _Tensor_::ReshapeCopied(const std::vector &shape, poly::StageMa } Shared CreateStage(Tensor tensor) { - return poly::Stage::New(tensor->GenerateIslDomain(), tensor->body(), tensor.self()); + auto isl_domain = tensor->GenerateIslDomain(); + return poly::Stage::New(isl_domain, tensor->body(), tensor.self()); } } // namespace ir diff --git a/cinn/pybind/frontend.cc b/cinn/pybind/frontend.cc index a617a60ff2..5daeeaa0e8 100644 --- a/cinn/pybind/frontend.cc +++ b/cinn/pybind/frontend.cc @@ -2,7 +2,8 @@ #include #include #include -#include "cinn/common/type.h" + +#include "cinn/common/common.h" #include "cinn/frontend/syntax.h" #include "cinn/hlir/framework/graph.h" #include "cinn/hlir/framework/graph_compiler.h" diff --git a/cinn/pybind/runtime.cc b/cinn/pybind/runtime.cc index f46bf36590..72d7772a7a 100644 --- a/cinn/pybind/runtime.cc +++ b/cinn/pybind/runtime.cc @@ -17,17 +17,17 @@ using py::arg; void BindCinnRuntime(py::module *); cinn_type_t NumpyTypeToCinn(py::dtype dt) { - if (dt == py::dtype::of()) { + if (dt.is(py::dtype::of())) { return cinn_int32_t(); - } else if (dt == py::dtype::of()) { + } else if (dt.is(py::dtype::of())) { return cinn_int64_t(); - } else if (dt == py::dtype::of()) { + } else if 
(dt.is(py::dtype::of())) { return cinn_uint32_t(); - } else if (dt == py::dtype::of()) { + } else if (dt.is(py::dtype::of())) { return cinn_uint64_t(); - } else if (dt == py::dtype::of()) { + } else if (dt.is(py::dtype::of())) { return cinn_float32_t(); - } else if (dt == py::dtype::of()) { + } else if (dt.is(py::dtype::of())) { return cinn_float64_t(); } return cinn_unk_t(); diff --git a/cinn/runtime/cpu/host_intrinsics.cc b/cinn/runtime/cpu/host_intrinsics.cc index a94baf0f4d..c095e12578 100644 --- a/cinn/runtime/cpu/host_intrinsics.cc +++ b/cinn/runtime/cpu/host_intrinsics.cc @@ -1,9 +1,10 @@ +#include "cinn/runtime/cpu/host_intrinsics.h" + #include #include #include "cinn/backends/extern_func_jit_register.h" #include "cinn/backends/function_prototype.h" -#include "cinn/runtime/cpu/host_intrinsics.h" #include "cinn/runtime/cpu/mkl_math.h" extern "C" { @@ -107,4 +108,6 @@ CINN_REGISTER_HELPER(host_intrinsics) { REGISTER_EXTERN_FUNC_2_IN_1_OUT_INT(bitwise_and); REGISTER_EXTERN_FUNC_2_IN_1_OUT_INT(bitwise_xor); REGISTER_EXTERN_FUNC_1_IN_1_OUT_INT(bitwise_not); + + return true; } diff --git a/cinn/utils/string.cc b/cinn/utils/string.cc index 55f5b97555..762bdd004a 100644 --- a/cinn/utils/string.cc +++ b/cinn/utils/string.cc @@ -26,21 +26,6 @@ std::string StringFormat(const std::string &fmt_str, ...) { return std::string(formatted.get()); } -std::string Join(const std::vector &fields, const std::string &splitter) { - std::stringstream ss; - if (fields.empty()) return ""; - for (int i = 0; i < fields.size() - 1; i++) { - ss << fields[i]; - ss << splitter; - } - - if (fields.size() > 0) { - ss << fields.back(); - } - - return ss.str(); -} - std::string Trim(const std::string &s, const char *empty) { if (s.empty()) return s; auto start = s.find_first_not_of(empty); diff --git a/cinn/utils/string.h b/cinn/utils/string.h index bdd8bca222..17da235aad 100644 --- a/cinn/utils/string.h +++ b/cinn/utils/string.h @@ -22,7 +22,14 @@ std::string StringFormat(const std::string& fmt_str, ...); /** * Join multiple fields to a single string. Similar to Python's str.join method. */ -std::string Join(const std::vector& fields, const std::string& splitter); +template +std::string Join(const std::vector& fields, const std::string& splitter) { + if (fields.empty()) return ""; + std::stringstream ss; + for (int i = 0; i < fields.size() - 1; i++) ss << fields[i] << splitter; + ss << fields.back(); + return ss.str(); +} std::vector Split(const std::string& str, const std::string& splitter); diff --git a/python/tests/test_op_broadcast.py b/python/tests/test_op_broadcast.py index 06ebb6a9c0..44045c4ff1 100644 --- a/python/tests/test_op_broadcast.py +++ b/python/tests/test_op_broadcast.py @@ -59,6 +59,7 @@ def test_op(self): attrs.attr_store = {"axis": 1} self.to_test_op([[3, 2], [2]], [[3, 2]], "elementwise_mul", attrs) + class OpTest_scale_0(SingleOpTester): def create_target_data(self, inputs_data): [X] = inputs_data