diff --git a/cinn/hlir/pass/CMakeLists.txt b/cinn/hlir/pass/CMakeLists.txt index e709eb6105..4daeceba9f 100755 --- a/cinn/hlir/pass/CMakeLists.txt +++ b/cinn/hlir/pass/CMakeLists.txt @@ -6,6 +6,7 @@ gather_srcs(cinnapi_src SRCS opfusion.cc alterlayout.cc const_propagate.cc + convert_to_more_fusion.cc ) @@ -15,3 +16,4 @@ if (NOT WITH_CUDA) cc_test(test_alterlayout SRCS alterlayout_test.cc DEPS cinncore) endif() cc_test(test_const_propagate SRCS const_propagate_test.cc DEPS cinncore) +cc_test(test_convert_to_more_fusion SRCS convert_to_more_fusion_test.cc DEPS cinncore) diff --git a/cinn/hlir/pass/convert_to_more_fusion.cc b/cinn/hlir/pass/convert_to_more_fusion.cc new file mode 100644 index 0000000000..60d79eb842 --- /dev/null +++ b/cinn/hlir/pass/convert_to_more_fusion.cc @@ -0,0 +1,102 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cinn/hlir/framework/pass.h" + +namespace cinn { +namespace hlir { +namespace pass { + +class ConvertStage { + public: + virtual bool CanConvert(const framework::Node*) = 0; + virtual void TryConvert(const framework::Node*, framework::Graph*) = 0; +}; + +class ConvertPipeline { + public: + template + void AddStage() { + stages_.push_back(std::make_unique()); + } + + void Run(framework::Graph* graph) { + auto grapd_nodes = std::get<0>(graph->topological_order()); + for (auto graph_node : grapd_nodes) { + auto op_node = graph_node->safe_as(); + if (!op_node) continue; + for (auto& stage : stages_) { + if (stage->CanConvert(op_node)) { + stage->TryConvert(op_node, graph); + } + } + } + } + + private: + std::vector> stages_; +}; + +class ConvertAddRelu : public ConvertStage { + public: + bool CanConvert(const framework::Node* op_node) override { + return "elementwise_add" == op_node->attrs.node_name; + } + + void TryConvert(const framework::Node* ewadd, framework::Graph* graph) override { + auto ewadd_out = ewadd->outlinks_in_order().front()->sink()->safe_as(); + if (ewadd_out->outlinks().size() == 1) return; + std::vector relu_nodes; + for (auto link : ewadd_out->outlinks()) { + auto link_op = link->sink()->safe_as(); + if ("relu" == link_op->attrs.node_name) { + relu_nodes.push_back(link_op); + } + } + size_t start_index = 0; + if (ewadd_out->outlinks().size() == relu_nodes.size()) { + start_index = 1; + } + for (size_t i = start_index; i < relu_nodes.size(); ++i) { + framework::Node* node = new framework::Node(ewadd->op(), ewadd->attrs.node_name, ewadd->id() + "_" + std::to_string(i)); + framework::NodeData* node_data = new framework::NodeData(std::shared_ptr(node), 0, 0, ewadd_out->id() + "_" + std::to_string(i)); + for (auto ewadd_inlink : ewadd->inlinks()) { + ewadd_inlink->source()->LinkTo(node); + } + node->LinkTo(node_data); + node_data->LinkTo(relu_nodes[i]); + graph->RegisterNode(node_data->id(), node_data); + graph->RegisterNode(node->id(), node); + ewadd_out->UnLinkTo(relu_nodes[i]); + } + } +}; + +void ConvertToMoreFusionPass(framework::Graph* graph) { + ConvertPipeline pipeline; + pipeline.AddStage(); + pipeline.Run(graph); +} + +} // namespace pass +} // namespace hlir +} // namespace cinn + +CINN_REGISTER_HELPER(ConvertToMoreFusion) { + CINN_REGISTER_PASS(ConvertToMoreFusion) + .describe("This pass") + .set_change_structure(true) + .set_body(cinn::hlir::pass::ConvertToMoreFusionPass); + return true; +} diff --git a/cinn/hlir/pass/convert_to_more_fusion_test.cc b/cinn/hlir/pass/convert_to_more_fusion_test.cc new file mode 100644 index 0000000000..caa40011f0 --- /dev/null +++ b/cinn/hlir/pass/convert_to_more_fusion_test.cc @@ -0,0 +1,68 @@ +// Copyright (c) 2021 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "cinn/cinn.h" +#include "cinn/frontend/syntax.h" +#include "cinn/hlir/framework/graph.h" +#include "cinn/hlir/framework/graph_compiler.h" +#include "cinn/hlir/framework/pass.h" +#include "cinn/hlir/op/use_ops.h" +#include "cinn/hlir/pass/use_pass.h" + +namespace cinn { +namespace frontend { + +TEST(convert_to_more_fusion, simple_convert) { + Placeholder A(Float(32), {3, 3}, "A"); + Placeholder B(Float(32), {3, 3}, "B"); + + Program program; +#ifdef CINN_WITH_CUDA + Target target = common::DefaultNVGPUTarget(); +#else + Target target = common::DefaultHostTarget(); +#endif + auto ewadd = program.elementwise_add(A, B); + auto relu = program.relu(ewadd); + auto reduce_sum = program.reduce_sum(ewadd, {0, 1}); + program.SetInputs({A, B}); + program.Validate(); + LOG(INFO) << "Program:\n" << program; + + auto graph = std::make_shared(program, target); + auto get_ewadd_nodes = [](const common::GraphNode* graph_node) -> bool { + auto op_node = graph_node->safe_as(); + return op_node && "elementwise_add" == op_node->attrs.node_name; + }; + // CollectNodes + LOG(INFO) << "original graph:\n" << graph->Visualize(); + auto before_ewadd_nodes = graph->CollectNodes(get_ewadd_nodes); + hlir::framework::ApplyPass(graph.get(), "ConvertToMoreFusion"); + auto after_ewadd_nodes = graph->CollectNodes(get_ewadd_nodes); + ASSERT_EQ(before_ewadd_nodes.size(), after_ewadd_nodes.size() - 1); + LOG(INFO) << "after pass graph:\n" << graph->Visualize(); + hlir::framework::ApplyPass(graph.get(), "InferShape"); + hlir::framework::ApplyPass(graph.get(), "OpFusion"); + + auto scope = hlir::framework::BuildScope(target, graph); + hlir::framework::GraphCompiler gc(target, scope, graph); + auto runtime_program = gc.Build(); + // check compiled program has 2 kernel + ASSERT_EQ(runtime_program->size(), 2); +} + +} // namespace frontend +} // namespace cinn diff --git a/cinn/hlir/pass/use_pass.h b/cinn/hlir/pass/use_pass.h index a3423981fb..79a5dcfe6b 100644 --- a/cinn/hlir/pass/use_pass.h +++ b/cinn/hlir/pass/use_pass.h @@ -20,3 +20,4 @@ CINN_USE_REGISTER(InferShape) CINN_USE_REGISTER(OpFusion) CINN_USE_REGISTER(AlterLayout) CINN_USE_REGISTER(ConstPropagate) +CINN_USE_REGISTER(ConvertToMoreFusion)