This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit 18f90a2

fix
SunNy820828449 committed Mar 13, 2023
1 parent ab14a30 commit 18f90a2
Showing 2 changed files with 2 additions and 50 deletions.
48 changes: 0 additions & 48 deletions cinn/hlir/framework/op_lowering_test.cc
@@ -54,52 +54,6 @@ void CodeGen(ir::LoweredFunc& func) {
#endif
}

-/*
-TEST(OpFusionPass, Reduce_Fuse_Reduce_TEST_00) {
-  int h = 32, w = 1024;
-  NetBuilder net_builder("Reduce_Fuse_Reduce_TEST_00");
-  // create model
-  {
-    auto A = net_builder.CreateInput(Float(32), {h, h, w}, "A");
-    auto B = net_builder.CreateInput(Float(32), {h, h, w}, "B");
-    auto C = net_builder.CreateInput(Float(32), {h, h, w}, "C");
-    auto D = net_builder.CreateInput(Float(32), {h, h, w}, "D");
-    auto E = net_builder.ReduceSum(A, {1, 2});
-    auto EE = net_builder.Exp(E);
-    auto EEE = net_builder.Add(EE, EE);
-    auto F = net_builder.ReduceSum(B, {1, 2});
-    auto FF = net_builder.Exp(F);
-    auto FFF = net_builder.Add(FF, FF);
-    auto G = net_builder.ReduceSum(C, {1, 2});
-    auto H = net_builder.ReduceSum(D, {1, 2});
-    auto I = net_builder.Add(EEE, FFF);
-    auto J = net_builder.Add(I, G);
-    auto K = net_builder.Add(J, H);
-  }
-  auto program = net_builder.Build();
-  auto target = common::DefaultTarget();
-  RunDecomposer(&program, target);
-  auto graph = std::make_shared<hlir::framework::Graph>(program, target);
-  hlir::framework::ApplyPass(graph.get(), "OpFusionPass");
-  hlir::framework::ApplyPass(graph.get(), "FusionMergePass");
-  CHECK_EQ(graph->fusion_groups.size(), 1);
-  auto& dtype_dict = graph->GetMutableAttrs<absl::flat_hash_map<std::string, Type>>("inferdtype");
-  auto& shape_dict = graph->GetMutableAttrs<absl::flat_hash_map<std::string, shape_t>>("infershape");
-  OpLowerer op_lowerer(dtype_dict, shape_dict, target);
-  for (auto& fusion_op : graph->fusion_groups) {
-    auto lowered_func = op_lowerer.Lower(fusion_op);
-    CHECK_EQ(lowered_func.size(), 1);
-    CodeGen(lowered_func[0]);
-  }
-}
-*/

TEST(OpFusionPass, Reduce_Fuse_Broadcast_Layernorm) {
int h = 32, w = 1024;
NetBuilder net_builder("Reduce_Fuse_Broadcast_Layernorm");
@@ -196,8 +150,6 @@ TEST(OpFusionPass, Reduce_Fuse_Broadcast_Softmax) {
CHECK_EQ(lowered_func.size(), 1);
CodeGen(lowered_func[0]);
}
-
-  exit(0);
}

TEST(OpFusionPass, Reduce_Fuse_Broadcast_1) {
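
Aside from deleting the commented-out Reduce_Fuse_Reduce_TEST_00 test, this file loses a stray exit(0); left at the end of Reduce_Fuse_Broadcast_Softmax. Because every TEST in a gtest binary runs in a single process, an exit(0) there ends the process early and later tests in the file, such as Reduce_Fuse_Broadcast_1, would not run. A minimal, hypothetical sketch of that effect (plain gtest, no CINN types; the ExitDemo names are made up for illustration):

#include <cstdlib>
#include <gtest/gtest.h>

TEST(ExitDemo, First) {
  EXPECT_EQ(1 + 1, 2);
  exit(0);  // terminates the whole test binary; nothing after this test runs
}

TEST(ExitDemo, Second) {
  // Never reached while the exit(0) above is present; removing it restores this test.
  EXPECT_TRUE(true);
}

int main(int argc, char** argv) {
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}
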
4 changes: 2 additions & 2 deletions cinn/hlir/framework/op_lowering_util.cc
@@ -660,7 +660,7 @@ void LoopAssignReduce(ir::IRSchedule& ir_sch,
}
}

-auto copy_loop_info = [](std::vector<ir::Expr>& rloops, std::vector<ir::Expr>& loops) {
+auto copy_loop_info = [](std::vector<ir::Expr>& loops, std::vector<ir::Expr>& rloops) {
for (int idx = 0; idx < std::min(rloops.size(), loops.size()); ++idx) {
auto l0 = rloops[idx].As<ir::For>();
auto l1 = loops[idx].As<ir::For>();
@@ -691,7 +691,7 @@ void LoopAssignReduce(ir::IRSchedule& ir_sch,
ir_sch.Split(loops.back(), factors);
loops = ir_sch.GetLoops(node_data->id());
// copy loop info form rloops.
-copy_loop_info(rloops, loops);
+copy_loop_info(loops, rloops);
return;
}

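The two hunks in op_lowering_util.cc are one change: the copy_loop_info lambda now takes the destination loops first and the reference rloops second, and the call site shown here is updated to match, so the call reads as (destination, source) and lines up with the nearby comment that loop info is copied from rloops. Below is a standalone, hypothetical C++ sketch of that (destination, source) convention; SimpleLoop and its annotation field are illustrative stand-ins, not CINN's ir::For, which carries real scheduling state.

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

// Stand-in for a loop node; the real ir::For carries for-type and bind info.
struct SimpleLoop {
  std::string annotation;  // e.g. "gpu_block", "gpu_thread"
};

int main() {
  std::vector<SimpleLoop> rloops = {{"gpu_block"}, {"gpu_thread"}};  // reference loops
  std::vector<SimpleLoop> loops(3);                                  // loops to update

  // Destination first, source second, mirroring copy_loop_info(loops, rloops).
  auto copy_loop_info = [](std::vector<SimpleLoop>& loops, std::vector<SimpleLoop>& rloops) {
    for (std::size_t idx = 0; idx < std::min(rloops.size(), loops.size()); ++idx) {
      loops[idx].annotation = rloops[idx].annotation;  // copy loop info from rloops
    }
  };

  copy_loop_info(loops, rloops);
  for (const auto& l : loops) {
    std::cout << "'" << l.annotation << "'\n";  // prints 'gpu_block', 'gpu_thread', ''
  }
  return 0;
}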
