Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

[WIP]Add ConvertToMoreFusion Pass #591

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cinn/hlir/pass/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ gather_srcs(cinnapi_src SRCS
opfusion.cc
alterlayout.cc
const_propagate.cc
convert_to_more_fusion.cc
)


Expand All @@ -15,3 +16,4 @@ if (NOT WITH_CUDA)
cc_test(test_alterlayout SRCS alterlayout_test.cc DEPS cinncore)
endif()
cc_test(test_const_propagate SRCS const_propagate_test.cc DEPS cinncore)
cc_test(test_convert_to_more_fusion SRCS convert_to_more_fusion_test.cc DEPS cinncore)
102 changes: 102 additions & 0 deletions cinn/hlir/pass/convert_to_more_fusion.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cinn/hlir/framework/pass.h"

namespace cinn {
namespace hlir {
namespace pass {

class ConvertStage {
public:
virtual bool CanConvert(const framework::Node*) = 0;
virtual void TryConvert(const framework::Node*, framework::Graph*) = 0;
};

class ConvertPipeline {
public:
template <typename T>
void AddStage() {
stages_.push_back(std::make_unique<T>());
}

void Run(framework::Graph* graph) {
auto grapd_nodes = std::get<0>(graph->topological_order());
for (auto graph_node : grapd_nodes) {
auto op_node = graph_node->safe_as<framework::Node>();
if (!op_node) continue;
for (auto& stage : stages_) {
if (stage->CanConvert(op_node)) {
stage->TryConvert(op_node, graph);
}
}
}
}

private:
std::vector<std::unique_ptr<ConvertStage>> stages_;
};

class ConvertAddRelu : public ConvertStage {
public:
bool CanConvert(const framework::Node* op_node) override {
return "elementwise_add" == op_node->attrs.node_name;
}

void TryConvert(const framework::Node* ewadd, framework::Graph* graph) override {
auto ewadd_out = ewadd->outlinks_in_order().front()->sink()->safe_as<framework::NodeData>();
if (ewadd_out->outlinks().size() == 1) return;
std::vector<framework::Node*> relu_nodes;
for (auto link : ewadd_out->outlinks()) {
auto link_op = link->sink()->safe_as<framework::Node>();
if ("relu" == link_op->attrs.node_name) {
relu_nodes.push_back(link_op);
}
}
size_t start_index = 0;
if (ewadd_out->outlinks().size() == relu_nodes.size()) {
start_index = 1;
}
for (size_t i = start_index; i < relu_nodes.size(); ++i) {
framework::Node* node = new framework::Node(ewadd->op(), ewadd->attrs.node_name, ewadd->id() + "_" + std::to_string(i));
framework::NodeData* node_data = new framework::NodeData(std::shared_ptr<framework::Node>(node), 0, 0, ewadd_out->id() + "_" + std::to_string(i));
for (auto ewadd_inlink : ewadd->inlinks()) {
ewadd_inlink->source()->LinkTo(node);
}
node->LinkTo(node_data);
node_data->LinkTo(relu_nodes[i]);
graph->RegisterNode(node_data->id(), node_data);
graph->RegisterNode(node->id(), node);
ewadd_out->UnLinkTo(relu_nodes[i]);
}
}
};

void ConvertToMoreFusionPass(framework::Graph* graph) {
ConvertPipeline pipeline;
pipeline.AddStage<ConvertAddRelu>();
pipeline.Run(graph);
}

} // namespace pass
} // namespace hlir
} // namespace cinn

CINN_REGISTER_HELPER(ConvertToMoreFusion) {
CINN_REGISTER_PASS(ConvertToMoreFusion)
.describe("This pass")
.set_change_structure(true)
.set_body(cinn::hlir::pass::ConvertToMoreFusionPass);
return true;
}
68 changes: 68 additions & 0 deletions cinn/hlir/pass/convert_to_more_fusion_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gtest/gtest.h>

#include "cinn/cinn.h"
#include "cinn/frontend/syntax.h"
#include "cinn/hlir/framework/graph.h"
#include "cinn/hlir/framework/graph_compiler.h"
#include "cinn/hlir/framework/pass.h"
#include "cinn/hlir/op/use_ops.h"
#include "cinn/hlir/pass/use_pass.h"

namespace cinn {
namespace frontend {

TEST(convert_to_more_fusion, simple_convert) {
Placeholder A(Float(32), {3, 3}, "A");
Placeholder B(Float(32), {3, 3}, "B");

Program program;
#ifdef CINN_WITH_CUDA
Target target = common::DefaultNVGPUTarget();
#else
Target target = common::DefaultHostTarget();
#endif
auto ewadd = program.elementwise_add(A, B);
auto relu = program.relu(ewadd);
auto reduce_sum = program.reduce_sum(ewadd, {0, 1});
program.SetInputs({A, B});
program.Validate();
LOG(INFO) << "Program:\n" << program;

auto graph = std::make_shared<hlir::framework::Graph>(program, target);
auto get_ewadd_nodes = [](const common::GraphNode* graph_node) -> bool {
auto op_node = graph_node->safe_as<hlir::framework::Node>();
return op_node && "elementwise_add" == op_node->attrs.node_name;
};
// CollectNodes
LOG(INFO) << "original graph:\n" << graph->Visualize();
auto before_ewadd_nodes = graph->CollectNodes(get_ewadd_nodes);
hlir::framework::ApplyPass(graph.get(), "ConvertToMoreFusion");
auto after_ewadd_nodes = graph->CollectNodes(get_ewadd_nodes);
ASSERT_EQ(before_ewadd_nodes.size(), after_ewadd_nodes.size() - 1);
LOG(INFO) << "after pass graph:\n" << graph->Visualize();
hlir::framework::ApplyPass(graph.get(), "InferShape");
hlir::framework::ApplyPass(graph.get(), "OpFusion");

auto scope = hlir::framework::BuildScope(target, graph);
hlir::framework::GraphCompiler gc(target, scope, graph);
auto runtime_program = gc.Build();
// check compiled program has 2 kernel
ASSERT_EQ(runtime_program->size(), 2);
}

} // namespace frontend
} // namespace cinn
1 change: 1 addition & 0 deletions cinn/hlir/pass/use_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ CINN_USE_REGISTER(InferShape)
CINN_USE_REGISTER(OpFusion)
CINN_USE_REGISTER(AlterLayout)
CINN_USE_REGISTER(ConstPropagate)
CINN_USE_REGISTER(ConvertToMoreFusion)