From cb32c887b9ee0386782780f33686e92e07dc0a70 Mon Sep 17 00:00:00 2001 From: XiangGao Date: Tue, 9 May 2023 10:32:07 +0800 Subject: [PATCH] fix issue of fill_constant missing dtype (#1402) --- cinn/hlir/pass/constant_folding_pass_util.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cinn/hlir/pass/constant_folding_pass_util.h b/cinn/hlir/pass/constant_folding_pass_util.h index 6d2c495ba4..7e844075a0 100644 --- a/cinn/hlir/pass/constant_folding_pass_util.h +++ b/cinn/hlir/pass/constant_folding_pass_util.h @@ -30,6 +30,7 @@ inline void fold_broadcast_to_constant(const FusionHelperBase* helper, Graph* gr // create constant op. Node* node_tmp = new Node(Operator::Get("fill_constant"), "fill_constant", common::UniqName("fill_constant")); // set node attr + node_tmp->attrs.attr_store["dtype"] = constant_op->attrs.attr_store.at("dtype"); node_tmp->attrs.attr_store["shape"] = shape; node_tmp->attrs.attr_store["value"] = constant_op->attrs.attr_store.at("value"); node_tmp->attrs.attr_store["force_cpu"] = false; @@ -61,6 +62,7 @@ inline void fold_reshape_fill_constant(const FusionHelperBase* helper, Graph* gr // create constant op. Node* node_tmp = new Node(Operator::Get("fill_constant"), "fill_constant", common::UniqName("fill_constant")); // set node attr + node_tmp->attrs.attr_store["dtype"] = constant_op->attrs.attr_store.at("dtype"); node_tmp->attrs.attr_store["shape"] = shape; node_tmp->attrs.attr_store["value"] = constant_op->attrs.attr_store.at("value"); node_tmp->attrs.attr_store["force_cpu"] = false; @@ -108,6 +110,7 @@ inline void fold_squeeze_fill_constant(const FusionHelperBase* helper, Graph* gr } } + node_tmp->attrs.attr_store["dtype"] = constant_op->attrs.attr_store.at("dtype"); node_tmp->attrs.attr_store["shape"] = n_shape; node_tmp->attrs.attr_store["value"] = constant_op->attrs.attr_store.at("value"); node_tmp->attrs.attr_store["force_cpu"] = false; @@ -158,6 +161,7 @@ inline void fold_expand_dims_fill_constant(const FusionHelperBase* helper, Graph } // set node attr + node_tmp->attrs.attr_store["dtype"] = constant_op->attrs.attr_store.at("dtype"); node_tmp->attrs.attr_store["shape"] = n_shape; node_tmp->attrs.attr_store["value"] = constant_op->attrs.attr_store.at("value"); node_tmp->attrs.attr_store["force_cpu"] = false;