diff --git a/cinn/hlir/pass/fusion_merge_pass_util.h b/cinn/hlir/pass/fusion_merge_pass_util.h index a37126e764..696a55f1cf 100644 --- a/cinn/hlir/pass/fusion_merge_pass_util.h +++ b/cinn/hlir/pass/fusion_merge_pass_util.h @@ -118,6 +118,14 @@ CONDITION_FUNC(elementwise_fuse_reduce) { } } CHECK(reducer) << "Can't find reduce op in group " << second->group_id; + + // If the elementwise's output should be fetched, the output var cannot be compute inline + // into reduce's loop, in other words, the elementwise's cannot fused into reduce's loop + // Like: group1 = {cast_0}, group2={broadcast_0 -> elementwise_0 -> cast_1 -> reduce_max_0} + if (helper->output_nodes_set_.count(*first->master_nodes.begin())) { + return false; + } + auto input_shape = helper->shape_dict_.at(reducer->inlinks_in_order()[0]->source()->id()); auto reduce_axes = absl::get>(reducer->attrs.attr_store.at("dim"));