From d25dba76b28b8efa1ed12d77532e23388bbec7d4 Mon Sep 17 00:00:00 2001 From: jiangcheng Date: Fri, 5 May 2023 10:27:48 +0800 Subject: [PATCH] fix elementwise cannot fuse with reduce when output has fetch (#1386) --- cinn/hlir/pass/fusion_merge_pass_util.h | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/cinn/hlir/pass/fusion_merge_pass_util.h b/cinn/hlir/pass/fusion_merge_pass_util.h index a37126e764..696a55f1cf 100644 --- a/cinn/hlir/pass/fusion_merge_pass_util.h +++ b/cinn/hlir/pass/fusion_merge_pass_util.h @@ -118,6 +118,14 @@ CONDITION_FUNC(elementwise_fuse_reduce) { } } CHECK(reducer) << "Can't find reduce op in group " << second->group_id; + + // If the elementwise's output should be fetched, the output var cannot be compute inline + // into reduce's loop, in other words, the elementwise's cannot fused into reduce's loop + // Like: group1 = {cast_0}, group2={broadcast_0 -> elementwise_0 -> cast_1 -> reduce_max_0} + if (helper->output_nodes_set_.count(*first->master_nodes.begin())) { + return false; + } + auto input_shape = helper->shape_dict_.at(reducer->inlinks_in_order()[0]->source()->id()); auto reduce_axes = absl::get>(reducer->attrs.attr_store.at("dim"));