From 24594ede95cc5ac72bc2b124a56a52ca172eb899 Mon Sep 17 00:00:00 2001 From: Reinerzhou <1768552509@qq.com> Date: Tue, 21 Nov 2023 08:28:30 +0000 Subject: [PATCH] fix para dtype conversion. --- dicp/dicp/vendor/TopsGraph/codegen/enflame.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/dicp/dicp/vendor/TopsGraph/codegen/enflame.py b/dicp/dicp/vendor/TopsGraph/codegen/enflame.py index a4dd63144..e17e33ab8 100644 --- a/dicp/dicp/vendor/TopsGraph/codegen/enflame.py +++ b/dicp/dicp/vendor/TopsGraph/codegen/enflame.py @@ -582,8 +582,10 @@ def make_const_if_scalar(op_var, value, dtype=torch.float32, count=0): if isinstance(value, numbers.Number): src_code = f"{cxx_type_set[dtype]} {op_var}_const_value{count} = static_cast<{cxx_type_set[dtype]}>({value});\n" value = f"{op_var}_const{count}" - src_code += f"builder::Type {op_var}_const_type{count}({{1}}, {type_set[dtype]});\n" - src_code += f"builder::Op {value} = builder::Const(hlir_builder, static_cast(&{op_var}_const_value{count}), {op_var}_const_type{count});\n" + const_type = dtype if dtype != torch.float16 else torch.float32 + src_code += f"builder::Op {value} = builder::Const(hlir_builder, static_cast(&{op_var}_const_value{count}), builder::Type({{1}}, {type_set[const_type]}));\n" + if dtype == torch.float16: + src_code += f"{value} = builder::Convert({value}, builder::Type({{1}}, {type_set[dtype]}));\n" return src_code, value @staticmethod @@ -614,7 +616,9 @@ def Div(op_var, shape, dtype, x, y, **kwargs_list): @staticmethod def Sub(op_var, shape, dtype, x, y, **kwargs_list): - src_code, y = EnflameOverrides.make_const_if_scalar(op_var, y, dtype) + src_code_x, x = EnflameOverrides.make_const_if_scalar(op_var, x, dtype) + src_code_y, y = EnflameOverrides.make_const_if_scalar(op_var, y, dtype) + src_code = src_code_x + src_code_y src_code += f"builder::Op {op_var} = builder::Sub({x}, {y});" return src_code @@ -679,7 +683,7 @@ def Neg(op_var, shape, dtype, x, **kwargs_list): @staticmethod def Pow(op_var, shape, dtype, x, y, **kwargs_list): - src_code, y = EnflameOverrides.make_const_if_scalar(op_var, y) + src_code, y = EnflameOverrides.make_const_if_scalar(op_var, y, dtype) src_code += f"builder::Op {op_var} = builder::Pow({x}, {y});" return src_code @@ -761,7 +765,8 @@ def Full(op_var, out_shape, out_dtype, size, value, **kwargs_list): @staticmethod def FullLike(op_var, shape, dtype, x, value, **kwargs_list): - return f"builder::Op {op_var} = builder::FullLike({x}, {value});" + src_code = f"builder::Op {op_var} = builder::FullLike({x}, {value}, {type_set[dtype]}, {{{str(shape).split('[')[-1].split(']')[0]}}});" + return src_code @staticmethod def Transpose(op_var, shape, dtype, x, permution=[0, 1], **kwargs_list):