This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Pass api #1527

Open · wants to merge 92 commits into base: develop
Changes from all commits (92 commits)
25b59fd
add tensor_interface class
zyfncg Jun 9, 2023
dee0847
add '+=' for TensorInterfaceList
zyfncg Jun 9, 2023
7c8bb26
Merge pull request #2 from zyfncg/group_refactor
zyfncg Jun 9, 2023
6c58e0d
polish code
zyfncg Jun 9, 2023
6714d21
Merge pull request #3 from zyfncg/group_refactor
zyfncg Jun 9, 2023
37ce34b
add OpGroupInterface class
zyfncg Jun 9, 2023
922df41
add OpGroupInterface class
zyfncg Jun 9, 2023
c21f972
Merge pull request #4 from zyfncg/group_refactor
zyfncg Jun 9, 2023
a1c2b70
Leave function parameter unchanged
jiahy0825 Jun 9, 2023
530b980
Fix vertical/horizontal fuse funcion parameter
jiahy0825 Jun 9, 2023
8cd38d7
fix shared_ptr bug
jiahy0825 Jun 9, 2023
adc93ef
Merge pull request #5 from jiahy0825/update_comsumer_and_producer_type
jiahy0825 Jun 9, 2023
73fce39
add FusePassContext
zyfncg Jun 9, 2023
59a8ee3
Merge pull request #6 from zyfncg/pass_api
zyfncg Jun 9, 2023
a6d7413
Add topology relevant algorithms: dfs,bfs,scc,topo,is_reachable
jiahy0825 Jun 12, 2023
2bb6bc7
Merge pull request #8 from jiahy0825/topo-relevant-algo
jiahy0825 Jun 12, 2023
f983267
horizontal fuse
jiahy0825 Jun 12, 2023
8b0f948
horizontal fuse, code complete, wait for debug
jiahy0825 Jun 12, 2023
59e3e4d
Pass Horizontal Fuse Test
jiahy0825 Jun 12, 2023
e1dc74d
Merge pull request #9 from jiahy0825/support-horizontal
jiahy0825 Jun 12, 2023
8dfd12c
temp save
zyfncg Jun 12, 2023
433dad2
Merge branch 'pass_api' of github.com:jiahy0825/CINN into pass_api_v0
zyfncg Jun 13, 2023
835889c
revert for debug
zyfncg Jun 13, 2023
740819a
debug change
zyfncg Jun 13, 2023
23f6ac4
fix vertical fuse bug
zyfncg Jun 13, 2023
d15e3f7
polish code
zyfncg Jun 13, 2023
f3c6426
Merge pull request #10 from zyfncg/pass_api_v0
zyfncg Jun 13, 2023
d921da6
Refactor Recompute Fuse
zyfncg Jun 13, 2023
8159609
Merge pull request #11 from zyfncg/pass_api_v0
jiahy0825 Jun 13, 2023
1969a79
Support InputFuse
jiahy0825 Jun 13, 2023
8ab23ff
Merge branch 'pass_api' of https://github.com/jiahy0825/CINN into del…
jiahy0825 Jun 13, 2023
5fa2c7d
update recompute codes
jiahy0825 Jun 13, 2023
bc7d3ad
Fix loop error
jiahy0825 Jun 13, 2023
31caa33
Change to GeneralInputFuse in DoGeneralRecomputeAndVerticalFusion
jiahy0825 Jun 14, 2023
314f01c
fix DetectCycle
zyfncg Jun 14, 2023
2eae2e6
Merge pull request #14 from zyfncg/pass_api_v0
zyfncg Jun 14, 2023
9b9d20b
Merge branch 'pass_api' of https://github.com/jiahy0825/CINN into del…
jiahy0825 Jun 14, 2023
5507b5a
fix some random bug
zyfncg Jun 14, 2023
45e122b
Merge pull request #12 from jiahy0825/delete_relation_map
jiahy0825 Jun 14, 2023
157800c
Merge branch 'pass_api' of github.com:jiahy0825/CINN into pass_api_v0
zyfncg Jun 14, 2023
1e3d4a0
Merge pull request #15 from zyfncg/pass_api_v0
zyfncg Jun 14, 2023
0888e1a
Support pass register
jiahy0825 Jun 15, 2023
ff16f4d
Merge pull request #18 from jiahy0825/register-pass
jiahy0825 Jun 16, 2023
ac56151
update node inferface
zyfncg Jun 19, 2023
052749e
update
zyfncg Jun 20, 2023
f2650ce
update
zyfncg Jun 20, 2023
110a495
update
zyfncg Jun 20, 2023
2c767e4
change shared_ptr to OpGroup object in iterator
zyfncg Jun 20, 2023
c5a2560
modify is_same_size by new interface
zyfncg Jun 20, 2023
0245b9b
refactor HorizontalElementwiseFuseReduce
zyfncg Jun 25, 2023
4ac113b
Merge branch 'develop' of https://github.com/PaddlePaddle/CINN into p…
zyfncg Jun 25, 2023
f8bb713
update
zyfncg Jun 25, 2023
2757279
Merge pull request #20 from zyfncg/pass_api_update
zyfncg Jun 25, 2023
1590a33
Merge branch 'pass_api' of github.com:jiahy0825/CINN into pass_api_v1
zyfncg Jun 25, 2023
929b159
fix bug
zyfncg Jun 25, 2023
fe12227
fix api headerfile not found error
jiahy0825 Jun 26, 2023
12494a9
Merge pull request #21 from jiahy0825/fix-headerfile-not-found
jiahy0825 Jun 26, 2023
cb3fa1f
update interface of op and tensor
zyfncg Jun 26, 2023
96a0e03
refine code
zyfncg Jun 26, 2023
2efd7c3
add Shape class
zyfncg Jun 26, 2023
6d30239
refine view interface
zyfncg Jun 26, 2023
495f0b4
modify shared_ptr of Group
zyfncg Jun 26, 2023
46ace80
fix-accuracy-test-bug
jiahy0825 Jun 27, 2023
007d67f
fix bug
zyfncg Jun 27, 2023
d50f1be
add graph point into group
zyfncg Jun 27, 2023
0a9b7b2
add shape.h file
zyfncg Jun 27, 2023
bda6e17
Merge branch 'pass_api' of github.com:jiahy0825/CINN into pass_api_v1
zyfncg Jun 28, 2023
8e16c50
revert producer and consumer group from map to set in group
zyfncg Jun 28, 2023
0274349
replace ops with WalkOpNodes in op_graph
zyfncg Jun 28, 2023
6925517
delete unused file
zyfncg Jun 28, 2023
ad17701
fix some bugs
jiahy0825 Jun 28, 2023
5406478
delele unused code
zyfncg Jun 28, 2023
ca648f9
develop comment
zyfncg Jun 28, 2023
602ba9d
save workspace
jiahy0825 Jun 28, 2023
c4c4a78
aligned
jiahy0825 Jun 28, 2023
35b40ba
fully aligned
jiahy0825 Jun 29, 2023
d4c5e9c
fix bug
zyfncg Jun 29, 2023
87e9a47
polish code
zyfncg Jun 29, 2023
6dc511d
fully aligned and add debug message
jiahy0825 Jun 30, 2023
a40aab7
Add trick for BERT
jiahy0825 Jun 30, 2023
57557d5
update readme
zyfncg Jun 30, 2023
172ba95
Remove IsDependency Trick
jiahy0825 Jun 30, 2023
91f67a6
Remove trick to HorizontalFusePass
jiahy0825 Jul 3, 2023
c9f6b75
add utils header
zyfncg Jul 3, 2023
34be0ee
remove debug log
jiahy0825 Jul 3, 2023
dd334db
remove debug log
jiahy0825 Jul 3, 2023
a0c271f
Merge pull request #22 from jiahy0825/fix-accuracy-test-bug
jiahy0825 Jul 3, 2023
9ade8a9
polish code
zyfncg Jul 3, 2023
695bf6b
Merge branch 'pass_api' of github.com:jiahy0825/CINN into pass_api_v1
zyfncg Jul 3, 2023
8d81033
Merge pull request #19 from zyfncg/pass_api_v1
zyfncg Jul 3, 2023
b0f39ff
change logic of get master node
zyfncg Jul 4, 2023
3e7508b
Merge pull request #25 from zyfncg/pass_api_v1
jiahy0825 Jul 4, 2023
1 change: 1 addition & 0 deletions cinn/CMakeLists.txt
@@ -2,6 +2,7 @@ if (WITH_TESTING)
cc_library(cinn_gtest_main SRCS gtest_main.cc DEPS gtest)
endif()

add_subdirectory(api)
add_subdirectory(auto_schedule)
add_subdirectory(common)
add_subdirectory(utils)
8 changes: 8 additions & 0 deletions cinn/api/CMakeLists.txt
@@ -0,0 +1,8 @@
core_gather_headers()

gather_srcs(cinnapi_src SRCS
op_node.cc
tensor_node.cc
)

message(STATUS "srcs: ${cinnapi_src}")
44 changes: 44 additions & 0 deletions cinn/api/README.md
@@ -0,0 +1,44 @@
The classes in this directory form the interface of the group fusion pass; you can use these APIs to build strategies for group fusion.

The classes and APIs are as follows:

`OpGroup` : A set of op nodes that will be passed to the CINN backend to generate kernel code. Two groups can be fused together according to the merging rules written in the passes.

`OpNode` : Maps to an op in the program.

`TensorNode` : Maps to a tensor in the program.

`Shape` : The shape information of a tensor.

`FusePassCtx` : The context passed as the parameter of a pass; it holds all the data you need inside the pass.

`FuseHelper` : Provides utility methods such as `DetectCycleIfFuse` to simplify pass development.

| Class | Method | Description |
| :--: | :--: | :--: |
| `OpGroup` | kind() | Get the kind of the group |
| | producers() | Get the producer groups of the current group |
| | consumers() | Get the consumer groups of the current group |
| | WalkOpNodes(const std::function<void(const OpNode&)>& VisitOpNode) | Visit the op_nodes in the group and execute the VisitOpNode function for each OpNode |
| | | |
| `OpNode` | kind() | Get the kind of the op_node |
| | inputs() | Get the input tensors of the op_node |
| | outputs() | Get the output tensors of the op_node |
| | GetAttr(const std::string& attr_name) | Get an attribute of the op_node by name |
| | | |
| `TensorNode` | shape() | Get the shape of the tensor |
| | producer() | Get the producer op_node of the tensor |
| | consumers() | Get the consumer op_nodes of the tensor |
| | | |
| `Shape` | numel() | Get the total number of elements in the shape |
| | Other methods are the same as those of std::vector<int64_t> | |
| | | |
| `LightwareFusePassCtx` | PickOpGroup() | Get the current group in the pass context |
| | void EnableFuse(const OpGroup& first, const OpGroup& second) | Mark two groups that can be fused together |
| | fuse_helper() | Get the fuse_helper provided by the pass context |
| | | |
| `InputFusePassCtx` | PickConsumersWithSameInputs() | Get all consumer groups of the graph's input tensors |
| | void EnableFuse(const OpGroup& first, const OpGroup& second) | Mark two groups that can be fused together |
| | fuse_helper() | Get the fuse_helper provided by the pass context |
| | | |
| `FuseHelper` | DetectCycleIfFuse(const OpGroup& first, const OpGroup& second) | Whether fusing the two groups would create a cycle in the graph |
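
As an end-to-end illustration, here is a minimal sketch of a fuse rule written against these APIs. It only calls methods listed in the table above; the function name, the assumption that `cinn/api/op_group.h` pulls in the ctx declarations, and the assumption that a pass body is a plain function taking a `LightwareFusePassCtx*` are all illustrative, not taken from this PR.

```cpp
#include "cinn/api/op_group.h"  // assumed to make the ctx/fuse_helper classes visible

namespace cinn {
namespace api {

// Sketch: fuse an element-wise group into each of its element-wise consumers,
// as long as the fusion would not create a cycle in the group graph.
void NaiveElementWiseVerticalFuse(LightwareFusePassCtx* ctx) {
  const OpGroup group = ctx->PickOpGroup();
  if (group.kind() != hlir::framework::kElementWise) {
    return;
  }
  for (const OpGroup& consumer : group.consumers()) {
    if (consumer.kind() != hlir::framework::kElementWise) {
      continue;
    }
    if (ctx->fuse_helper().DetectCycleIfFuse(group, consumer)) {
      continue;  // fusing this pair would introduce a cycle; skip it
    }
    ctx->EnableFuse(group, consumer);
  }
}

}  // namespace api
}  // namespace cinn
```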
155 changes: 155 additions & 0 deletions cinn/api/op_group.h
@@ -0,0 +1,155 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>

#include "cinn/api/op_node.h"
#include "cinn/hlir/framework/graph.h"
#include "cinn/hlir/pass/fusion_helper_base.h"

namespace cinn {
namespace api {

using Comparator = hlir::framework::Graph::Group::SharedGroupComparator;
using Hasher = hlir::framework::Graph::Group::SharedGroupHasher;

class OpGroup {
public:
OpGroup(const std::shared_ptr<hlir::framework::Graph::Group>& group) : group_(group) {}

OpGroup(const OpGroup& other) = default;

class OpGroupListIterator {
public:
OpGroupListIterator(
std::unordered_set<std::shared_ptr<hlir::framework::Graph::Group>, Hasher, Comparator>::const_iterator it)
: iter_(it) {}

OpGroupListIterator& operator++() {
++iter_;
return *this;
}

OpGroupListIterator operator++(int) {
OpGroupListIterator tmp = *this;
++iter_;
return tmp;
}

bool operator==(const OpGroupListIterator& other) const { return iter_ == other.iter_; }

bool operator!=(const OpGroupListIterator& other) const { return !(*this == other); }

OpGroup operator*() const { return OpGroup(*iter_); }

private:
std::unordered_set<std::shared_ptr<hlir::framework::Graph::Group>, Hasher, Comparator>::const_iterator iter_;
};

class ProducerOpGroupListView {
public:
ProducerOpGroupListView(const std::weak_ptr<hlir::framework::Graph::Group>& group) : group_(group) {}

ProducerOpGroupListView(const ProducerOpGroupListView& other) = delete;
ProducerOpGroupListView(ProducerOpGroupListView&& other) = delete;

ProducerOpGroupListView& operator=(const ProducerOpGroupListView& other) = delete;

using const_iterator = OpGroupListIterator;

size_t size() const { return group_.lock()->producer_groups().size(); }

const_iterator begin() const { return const_iterator(group_.lock()->producer_groups().begin()); }

const_iterator end() const { return const_iterator(group_.lock()->producer_groups().end()); }

private:
const std::weak_ptr<hlir::framework::Graph::Group> group_;
};

class ConsumerOpGroupListView {
public:
ConsumerOpGroupListView(const std::weak_ptr<hlir::framework::Graph::Group>& group) : group_(group) {}

ConsumerOpGroupListView(const ConsumerOpGroupListView& other) = delete;
ConsumerOpGroupListView(ConsumerOpGroupListView&& other) = delete;

ConsumerOpGroupListView& operator=(const ConsumerOpGroupListView& other) = delete;

using const_iterator = OpGroupListIterator;

size_t size() const { return group_.lock()->consumer_groups().size(); }

const_iterator begin() const { return const_iterator(group_.lock()->consumer_groups().begin()); }

const_iterator end() const { return const_iterator(group_.lock()->consumer_groups().end()); }

private:
const std::weak_ptr<hlir::framework::Graph::Group> group_;
};

const std::string& group_id() const { return group_.lock()->group_id; }

hlir::framework::OpPatternKind kind() const { return group_.lock()->kind(); }

// WalkOpNodes traverses the op_nodes in the group and executes the
// VisitOpNode function for each OpNode. It is equivalent to a for loop
// over the op_nodes of the group.
//
// To avoid unnecessary memory copies, we provide WalkOpNodes instead of
// a function that returns all op_nodes directly.
//
// Example: get all the Reduction op_nodes in the group.
// OpGroup group = ...;
// std::set<api::OpNode> reduce_op_set;
// group.WalkOpNodes([&reduce_op_set](const api::OpNode& op){
// // The VisitOpNode lambda collects the Reduction op_nodes.
// if (op.kind() == OpPatternKind::kReduction) {
// reduce_op_set.insert(op);
// }
// });
void WalkOpNodes(const std::function<void(const OpNode&)>& VisitOpNode) const {
group_.lock()->WalkNodes(
[&](const hlir::framework::Node* node) { VisitOpNode(OpNode(node, group_.lock()->graph_)); });
}

ProducerOpGroupListView producers() const { return ProducerOpGroupListView(group_); }

ConsumerOpGroupListView consumers() const { return ConsumerOpGroupListView(group_); }

std::shared_ptr<hlir::framework::Graph::Group> GetGroup() const { return group_.lock(); }

bool operator==(const OpGroup& other) const { return group_.lock().get() == other.group_.lock().get(); }

bool operator<(const OpGroup& other) const { return group_.lock().get() < other.group_.lock().get(); }

private:
const std::weak_ptr<hlir::framework::Graph::Group> group_;
};

} // namespace api
} // namespace cinn

namespace std {

template <>
struct hash<cinn::api::OpGroup> {
size_t operator()(const cinn::api::OpGroup& obj) const {
return std::hash<uint64_t>()(reinterpret_cast<uint64_t>(obj.GetGroup().get()));
}
};

} // namespace std
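
For illustration, a short usage sketch of the view classes above; the function name is hypothetical and not part of this PR:

```cpp
// Sketch: count the element-wise producers of a group by iterating the
// ProducerOpGroupListView returned by producers().
size_t CountElementWiseProducers(const cinn::api::OpGroup& group) {
  size_t count = 0;
  for (const cinn::api::OpGroup& producer : group.producers()) {
    if (producer.kind() == cinn::hlir::framework::kElementWise) {
      ++count;
    }
  }
  return count;
}
```

Note that `OpGroup` itself holds only a `weak_ptr` to the underlying group, so passing it by value is cheap; the view classes are non-copyable and are meant to be consumed in place, as in the loop above.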
29 changes: 29 additions & 0 deletions cinn/api/op_node.cc
@@ -0,0 +1,29 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cinn/api/op_node.h"

namespace cinn {
namespace api {

TensorNode OpNode::InputTensorListView::operator[](size_t index) const {
return TensorNode(edges_[index]->source()->safe_as<hlir::framework::NodeData>(), graph_);
}

TensorNode OpNode::OutputTensorListView::operator[](size_t index) const {
return TensorNode(edges_[index]->sink()->safe_as<hlir::framework::NodeData>(), graph_);
}

} // namespace api
} // namespace cinn
133 changes: 133 additions & 0 deletions cinn/api/op_node.h
@@ -0,0 +1,133 @@
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include "cinn/hlir/framework/graph.h"
#include "cinn/hlir/pass/fusion_helper_base.h"
#include "cinn/api/tensor_node.h"
#include "cinn/hlir/framework/op.h"

namespace cinn {
namespace api {

using OpPatternKind = cinn::hlir::framework::OpPatternKind;
using Attribute = cinn::utils::Attribute;

class OpNode {
public:
OpNode(const hlir::framework::Node* node, const hlir::framework::Graph* graph)
    : node_(node),
      graph_(graph),
      input_tensors_(node->inlinks_in_order(), graph_),
      output_tensors_(node->outlinks_in_order(), graph_) {
  VLOG(1) << "[OpNode] node: " << node->id();
}

OpNode(const OpNode& other)
    : node_(other.node_),
      graph_(other.graph_),
      input_tensors_(node_->inlinks_in_order(), graph_),
      output_tensors_(node_->outlinks_in_order(), graph_) {}

OpPatternKind kind() const {
  static thread_local const hlir::framework::OpValueType<OpPatternKind>& op_pattern_dict =
      hlir::framework::Operator::GetAttrs<OpPatternKind>("OpPattern");
  auto kind = op_pattern_dict[node_->op()];

  if (kind == hlir::framework::kBroadcast) {
    // Binary ops are registered as kBroadcast, but unless the op is
    // broadcast_to they actually behave element-wise.
    if (node_->op()->name != "broadcast_to") {
      return hlir::framework::kElementWise;
    }
  }
  return kind;
}

class InputTensorListView {
public:
InputTensorListView(const std::vector<common::Shared<common::GraphEdge>>& edges, const hlir::framework::Graph* graph) : edges_(edges), graph_(graph) {}

InputTensorListView(const InputTensorListView& other) = delete;
InputTensorListView(InputTensorListView&& other) = delete;

InputTensorListView& operator=(const InputTensorListView& other) = delete;

size_t size() const { return edges_.size(); }

TensorNode operator[](size_t index) const;

private:
std::vector<common::Shared<common::GraphEdge>> edges_;
const hlir::framework::Graph* graph_;
};

class OutputTensorListView {
public:
OutputTensorListView(const std::vector<common::Shared<common::GraphEdge>>& edges, const hlir::framework::Graph* graph) : edges_(edges), graph_(graph) {}

OutputTensorListView(const OutputTensorListView& other) = delete;
OutputTensorListView(OutputTensorListView&& other) = delete;

OutputTensorListView& operator=(const OutputTensorListView& other) = delete;

size_t size() const { return edges_.size(); }

TensorNode operator[](size_t index) const;

private:
std::vector<common::Shared<common::GraphEdge>> edges_;
const hlir::framework::Graph* graph_;
};

bool operator==(const OpNode& other) const { return node_ == other.node_; }

bool operator<(const OpNode& other) const { return node_ < other.node_; }

const InputTensorListView& inputs() const {
return input_tensors_;
}

const OutputTensorListView& outputs() const {
return output_tensors_;
}

template <typename T>
const T& GetAttr(const std::string& attr_name) const {
return absl::get<T>(GetAttr(attr_name));
}

private:
const Attribute& GetAttr(const std::string& attr_name) const {
return node_->attrs.attr_store.at(attr_name);
}

friend struct std::hash<OpNode>;

const hlir::framework::Node* node_;
const hlir::framework::Graph* graph_;

const InputTensorListView input_tensors_;
const OutputTensorListView output_tensors_;
};

} // namespace api
} // namespace cinn

namespace std {

template <>
struct hash<cinn::api::OpNode> {
size_t operator()(const cinn::api::OpNode& obj) const {
return std::hash<uint64_t>()(reinterpret_cast<uint64_t>(obj.node_));
}
};

} // namespace std
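
A brief usage sketch of the typed attribute accessor; the attribute name "dim" and its `std::vector<int>` type are hypothetical placeholders, not taken from this PR:

```cpp
// Sketch: inspect a reduction op through the OpNode interface. GetAttr<T>
// goes through absl::get, which throws if the stored type does not match T.
void InspectReduceOp(const cinn::api::OpNode& op) {
  if (op.kind() == cinn::api::OpPatternKind::kReduction) {
    const auto& dims = op.GetAttr<std::vector<int>>("dim");  // hypothetical attr name
    VLOG(1) << "reduction over " << dims.size() << " dimension(s), with "
            << op.inputs().size() << " input tensor(s)";
  }
}
```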