Skip to content

Commit

Permalink
fix para dtype conversion.
Browse files Browse the repository at this point in the history
  • Loading branch information
Reinerzhou committed Nov 21, 2023
1 parent 017cba8 commit 24594ed
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions dicp/dicp/vendor/TopsGraph/codegen/enflame.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,8 +582,10 @@ def make_const_if_scalar(op_var, value, dtype=torch.float32, count=0):
if isinstance(value, numbers.Number):
src_code = f"{cxx_type_set[dtype]} {op_var}_const_value{count} = static_cast<{cxx_type_set[dtype]}>({value});\n"
value = f"{op_var}_const{count}"
src_code += f"builder::Type {op_var}_const_type{count}({{1}}, {type_set[dtype]});\n"
src_code += f"builder::Op {value} = builder::Const(hlir_builder, static_cast<void *>(&{op_var}_const_value{count}), {op_var}_const_type{count});\n"
const_type = dtype if dtype != torch.float16 else torch.float32
src_code += f"builder::Op {value} = builder::Const(hlir_builder, static_cast<void *>(&{op_var}_const_value{count}), builder::Type({{1}}, {type_set[const_type]}));\n"
if dtype == torch.float16:
src_code += f"{value} = builder::Convert({value}, builder::Type({{1}}, {type_set[dtype]}));\n"
return src_code, value

@staticmethod
Expand Down Expand Up @@ -614,7 +616,9 @@ def Div(op_var, shape, dtype, x, y, **kwargs_list):

@staticmethod
def Sub(op_var, shape, dtype, x, y, **kwargs_list):
src_code, y = EnflameOverrides.make_const_if_scalar(op_var, y, dtype)
src_code_x, x = EnflameOverrides.make_const_if_scalar(op_var, x, dtype)
src_code_y, y = EnflameOverrides.make_const_if_scalar(op_var, y, dtype)
src_code = src_code_x + src_code_y
src_code += f"builder::Op {op_var} = builder::Sub({x}, {y});"
return src_code

Expand Down Expand Up @@ -679,7 +683,7 @@ def Neg(op_var, shape, dtype, x, **kwargs_list):

@staticmethod
def Pow(op_var, shape, dtype, x, y, **kwargs_list):
src_code, y = EnflameOverrides.make_const_if_scalar(op_var, y)
src_code, y = EnflameOverrides.make_const_if_scalar(op_var, y, dtype)
src_code += f"builder::Op {op_var} = builder::Pow({x}, {y});"
return src_code

Expand Down Expand Up @@ -761,7 +765,8 @@ def Full(op_var, out_shape, out_dtype, size, value, **kwargs_list):

@staticmethod
def FullLike(op_var, shape, dtype, x, value, **kwargs_list):
return f"builder::Op {op_var} = builder::FullLike({x}, {value});"
src_code = f"builder::Op {op_var} = builder::FullLike({x}, {value}, {type_set[dtype]}, {{{str(shape).split('[')[-1].split(']')[0]}}});"
return src_code

@staticmethod
def Transpose(op_var, shape, dtype, x, permution=[0, 1], **kwargs_list):
Expand Down

0 comments on commit 24594ed

Please sign in to comment.