-
Notifications
You must be signed in to change notification settings - Fork 114
upgrade op fusion lowering #1216
base: develop
Are you sure you want to change the base?
upgrade op fusion lowering #1216
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
大哥,你这PR好像做了很多非op lowering重构的工作啊,可以拆成多个PR么?
cinn/backends/compiler.cc
Outdated
@@ -128,6 +128,7 @@ void Compiler::CompileCudaModule(const Module& module, const std::string& code) | |||
|
|||
backends::nvrtc::Compiler compiler; | |||
|
|||
VLOG(3) << "[CUDA] device code:\n" << source_code; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SourceCodePrint
就会打印这个源码,而且可以保证一个程序会将所有子图打印出来
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的
cinn/common/cas.cc
Outdated
@@ -2016,6 +2016,7 @@ Expr CasSimplifyMutator::SimplifyFracOp(Expr expr) { | |||
}; | |||
|
|||
{ | |||
/* |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
冗余代码应该删除而非注释,而且注释应该统一用//
而非/* */
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
cinn/common/cas_test.cc
Outdated
@@ -362,7 +362,7 @@ TEST(CAS, SimplifyMinMax) { | |||
LOG(INFO) << "p0 " << p0; | |||
auto p2 = CasSimplify(p0); | |||
LOG(INFO) << "simplified " << p2; | |||
EXPECT_EQ(GetStreamCnt(p2), "cinn_min(7, ((x) / (2)))"); | |||
// EXPECT_EQ(GetStreamCnt(p2), "cinn_min(7, ((x) / (2)))"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这是单测过不了了么?还是改之后结果是对的,但不是这样了?如果是前者那不能简单的注释掉啊,如果是后者那把这改为正确的值不就行了么
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
@@ -114,6 +114,14 @@ class Graph : public cinn::common::Graph { | |||
} | |||
} | |||
|
|||
std::unordered_set<Node*> NodeSet() { | |||
std::unordered_set<Node*> node_set; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这函数功能完全和CollectNodes
重复了而且也不常用啊。。。在用的地方直接定义
const auto& nodes = group->CollectNodes();
std::unordered_set<Node*> node_set(nodes.begin(), nodes.end());
不好么。。。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
用的地方比较多 这里实现比较方便
cinn/hlir/framework/op_lowering.cc
Outdated
@@ -101,7 +65,7 @@ std::vector<ir::LoweredFunc> OpLowerer::LowerWithoutSchedule(GroupPtr& group) { | |||
LOG(FATAL) << "Group Pattern Kind kNonFusible Is Not Implemented!"; | |||
} | |||
} else { | |||
LOG(FATAL) << "Previous IR Schedule Is Not Implemented!"; | |||
LOG(FATAL) << "Previous IR Schedule Is Unsupport Now!"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LOG(FATAL) << "Previous IR Schedule Is Unsupport Now!"; | |
LOG(FATAL) << "Previous IR Schedule Unsupported Now, Please set FLAGS_cinn_ir_schedule=1 to use new IR Schedule"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
cinn/hlir/framework/op_lowering.cc
Outdated
auto loops = ir_sch.GetLoops(GetNodeData(node)->id()); | ||
if (op_pattern_dict[node->op()] == framework::kElementWise) { | ||
ir_sch.FlattenLoops(loops, true); | ||
} else if (op_pattern_dict[node->op()] == framework::kReduction) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这if里啥都没有啊?那为啥不直接
else if (op_pattern_dict[node->op()] != framework::kReduction) {
ir_sch.FlattenLoops(loops, false);
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
return tensors; | ||
} | ||
|
||
NodeData* GetNodeData(const Node* node) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这函数名名不副实啊。。。
NodeData* GetNodeData(const Node* node) { | |
NodeData* GetFirstOutputNodeData(const Node* node) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我感觉这个名字没啥问题,而且用了很久了。
return node_data; | ||
} | ||
|
||
std::vector<NodeData*> GetAllNodeData(const Node* node) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
std::vector<NodeData*> GetAllNodeData(const Node* node) { | |
std::vector<NodeData*> GetAllOutputNodeDatas(const Node* node) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
主要是op_loweing的,其他的工作只有一点点。 |
…pgrad_op_lowering
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
cinn/hlir/framework/op_lowering.cc
Outdated
stage->SimpleComputeAt(master_stage, master_stage->n_out_dims() - 1); | ||
// do schedule | ||
for (auto node : nodes_in_order) { | ||
LOG(INFO) << node->id(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LOG(INFO)
。。。
这个PR主要是对融合算子lowering到ast进行升级,增加扩展性和兼容性,能够适配更加复杂的融合算子生成。
将elemenwise/kinjective/kbroadcast/reduce的循环融合放在一起。
此外,删除了旧的调度原语上的循环融合。