Fix some parameter errors and rename the device name.
Reinerzhou committed Nov 21, 2023
1 parent 094bfe0 commit 5ec9c65
Showing 3 changed files with 22 additions and 58 deletions.
28 changes: 8 additions & 20 deletions dicp/dicp/vendor/TopsGraph/codegen/enflame.py
@@ -256,33 +256,21 @@ def gen_run_func_code(self):
         func_body = IndentedBuffer()
         func_body.writeline('std::vector<void *> input_ptrs;')
         for i in range(0, len(self.input_args)):
-            func_body.writeline(
-                f'input_ptrs.emplace_back(static_cast<void *>(input_ptr{str(i)}));')
+            func_body.writeline(f'input_ptrs.emplace_back(inputs_ptr[{str(i)}]);')

         func_body.writeline("")
         func_body.writeline("std::vector<void *> output_ptrs;")
+        output_ptr_count = 0
         for i in range(0, len(self.output_args)):
             if not isinstance(self.output_args[i], type(None)):
-                func_body.writeline(
-                    f'output_ptrs.emplace_back(output_ptr{str(i)});')
+                func_body.writeline(f'output_ptrs.emplace_back(outputs_ptr[{output_ptr_count}]);')
+                output_ptr_count += 1

         func_body.writeline("")
-        func_body.writeline(
-            f'run(exe_ptr, dipu_stream, input_ptrs, output_ptrs, {self.device_id}, {"true" if dipu_flag else "false"});')
-
-        input_paras = []
-        for i in range(0, len(self.input_args)):
-            input_paras.append(f"float* input_ptr{str(i)}")
-
-        output_paras = []
-        for i in range(0, len(self.output_args)):
-            if not isinstance(self.output_args[i], type(None)):
-                output_paras.append(f"float* output_ptr{str(i)}")
-        paras = input_paras + output_paras
+        func_body.writeline(f'run(exe_ptr, dipu_stream, input_ptrs, output_ptrs, {self.device_id}, {"true" if dipu_flag else "false"});')

         run_func_code = IndentedBuffer()
-        run_func_code.writeline(
-            f'extern "C" void run(void *dipu_stream, {", ".join(paras)}) {"{"}')
+        run_func_code.writeline(f'extern "C" void run(void *dipu_stream, void **inputs_ptr, void **outputs_ptr) {"{"}')
         with run_func_code.indent():
             run_func_code.splice(func_body)
             run_func_code.splice('}')
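The hunk above stops generating one float* parameter per tensor and instead packs all buffers into two void** arrays, counting only the non-None outputs. Below is a minimal, self-contained sketch of that packing logic under simplified assumptions (plain strings instead of IndentedBuffer; gen_run_func is an illustrative name, not the repository's API):

# Standalone sketch of the pointer-packing codegen shown in the hunk above.
# It emits plain strings; gen_run_func is illustrative, not the real method.
def gen_run_func(input_args, output_args, device_id=0, dipu_flag=False):
    body = ["    std::vector<void *> input_ptrs;"]
    for i in range(len(input_args)):
        body.append(f"    input_ptrs.emplace_back(inputs_ptr[{i}]);")
    body.append("    std::vector<void *> output_ptrs;")
    output_ptr_count = 0
    for arg in output_args:
        if arg is not None:  # None outputs are skipped, as in the diff
            body.append(f"    output_ptrs.emplace_back(outputs_ptr[{output_ptr_count}]);")
            output_ptr_count += 1
    flag = "true" if dipu_flag else "false"
    body.append(f"    run(exe_ptr, dipu_stream, input_ptrs, output_ptrs, {device_id}, {flag});")
    header = 'extern "C" void run(void *dipu_stream, void **inputs_ptr, void **outputs_ptr) {'
    return "\n".join([header] + body + ["}"])

print(gen_run_func(["a", "b"], ["x", None, "y"]))

With the pointers packed this way, the generated extern "C" signature stays fixed no matter how many inputs and outputs the graph has.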
@@ -395,7 +383,7 @@ def check_res(a, b, graph_name):

     def gen_tensor(self, prefix, tensor):
         if dipu_flag:
-            res = f"{prefix}({tuple(tensor.shape)}, {tensor.stride()}, device='xpu:{self.device_id}', dtype={tensor.dtype})"
+            res = f"{prefix}({tuple(tensor.shape)}, {tensor.stride()}, device='dipu:{self.device_id}', dtype={tensor.dtype})"
         else:
             res = f"{prefix}({tuple(tensor.shape)}, {tensor.stride()}, device='{tensor.device.type}', dtype={tensor.dtype})"
         # makes a copy of the tensor for view ops
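With dipu_flag enabled, gen_tensor now emits device='dipu:<id>' instead of device='xpu:<id>'. A hedged example of the kind of source line it produces (prefix, shape, stride, and device index are illustrative assumptions, not taken from the repository):

# Illustrative only: prefix, shape, stride, and device index are assumptions.
prefix, device_id = "torch.empty_strided", 0
shape, stride, dtype = (2, 3), (3, 1), "torch.float32"
res = f"{prefix}({shape}, {stride}, device='dipu:{device_id}', dtype={dtype})"
print(res)  # torch.empty_strided((2, 3), (3, 1), device='dipu:0', dtype=torch.float32)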
@@ -913,7 +901,7 @@ def Concatenate(op_var, out_shape, out_dtype, tensors, dim):
         return f"builder::Op {op_var} = builder::Concatenate({'{' + ', '.join(tensors) + '}'}, {dim});"

     @staticmethod
-    def Softmax(op_var, out_shape, out_dtype, x, y, z):
+    def Softmax(op_var, out_shape, out_dtype, x, y):
         return f"builder::Op {op_var} = builder::Softmax({x}, {y}, true);"

     @staticmethod
18 changes: 5 additions & 13 deletions dicp/dicp/vendor/TopsGraph/conversion.py
@@ -210,14 +210,10 @@ def HardswishBackward(self, *args, **kwargs):
     def Clone(self, *args, **kwargs):
         return self.get_proxy(tops_op.Clone, args, kwargs)

-    @register_conversion(aten.copy.default)
+    @register_conversion([aten.copy.default, aten.copy_.default])
     def Copy(self, *args, **kwargs):
         return self.get_proxy(tops_op.Copy, args, kwargs)

-    @register_conversion(aten.copy_.default)
-    def Copy_(self, *args, **kwargs):
-        return self.get_proxy(tops_op.Copy, args, kwargs)
-
     @register_conversion(aten.lift_fresh_copy.default)
     def LiftFreshCopy(self, *args, **kwargs):
         return self.get_proxy(tops_op.LiftFreshCopy, args, kwargs)
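The Copy conversion above now serves both aten.copy.default and aten.copy_.default through a single @register_conversion call that takes a list. The sketch below shows a decorator with that accept-one-or-many behavior; it is an assumption-level stand-in, not the repository's register_conversion, and strings stand in for the aten overload objects:

# Sketch of a registry decorator that accepts one op or a list of ops, so a
# single conversion function can cover several aten overloads.
CONVERSIONS = {}

def register_conversion(aten_ops):
    ops = aten_ops if isinstance(aten_ops, (list, tuple)) else [aten_ops]

    def decorator(fn):
        for op in ops:
            CONVERSIONS[op] = fn
        return fn
    return decorator

@register_conversion(["aten.copy.default", "aten.copy_.default"])
def Copy(*args, **kwargs):
    return ("tops_op.Copy", args, kwargs)

assert CONVERSIONS["aten.copy.default"] is CONVERSIONS["aten.copy_.default"]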
@@ -247,13 +243,9 @@ def Less(self, *args, **kwargs):
     def LessEqual(self, *args, **kwargs):
         return self.get_proxy(tops_op.LessEqual, args, kwargs)

-    @register_conversion(aten.eq.Tensor)
-    def Equal_Tensor(self, *args, **kwargs):
-        return self.get_proxy(tops_op.Equal_Tensor, args, kwargs)
-
-    @register_conversion(aten.eq.Scalar)
-    def Equal_Scalar(self, *args, **kwargs):
-        return self.get_proxy(tops_op.Equal_Scalar, args, kwargs)
+    @register_conversion([aten.eq.Tensor, aten.eq.Scalar])
+    def Equal(self, *args, **kwargs):
+        return self.get_proxy(tops_op.Equal, args, kwargs)

     @register_conversion(aten.ne.Scalar)
     def NotEqual(self, a, b):
@@ -332,7 +324,7 @@ def BatchNormBackward(*args, **kwargs):
     def Softmax(self, a, dim, half_to_float):
         out_shape = fx_traceback.get_current_meta()["val"].shape
         dim = dim + len(out_shape) if dim < 0 else dim
-        return self.get_proxy(tops_op.Softmax, (a, dim, half_to_float))
+        return self.get_proxy(tops_op.Softmax, (a, dim))

     @register_conversion(aten.mm)
     def Gemm(self, *args, **kwargs):
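The Softmax conversion keeps its negative-dim normalization but no longer forwards half_to_float, matching the two-argument Softmax codegen in enflame.py above. A small illustrative sketch tying the two sides together (op_var and the SSA-style names are made up for illustration):

# Ties the two Softmax changes together: the conversion normalizes dim and
# drops half_to_float, and the codegen emits a two-argument builder call.
def lower_softmax(x: str, dim: int, rank: int, op_var: str = "op_softmax") -> str:
    dim = dim + rank if dim < 0 else dim  # conversion.py side
    return f"builder::Op {op_var} = builder::Softmax({x}, {dim}, true);"  # enflame.py side

print(lower_softmax("op_3", -1, 4))
# builder::Op op_softmax = builder::Softmax(op_3, 3, true);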
34 changes: 9 additions & 25 deletions dicp/dicp/vendor/TopsGraph/tops_op.py
@@ -78,22 +78,6 @@ def __init__(self, *args, **kwargs):
     def __call__(self, *args, **kwargs):
         new_args = args[1:]
         return super().__call__(*new_args, **kwargs)
-
-
-class Equal_Tensor(Operator):
-    def __init__(self, *args, **kwargs):
-        super().__init__("Equal")
-        self.args = args
-        self.kwargs = kwargs
-        self.torch_op = aten.eq.Tensor
-
-
-class Equal_Scalar(Operator):
-    def __init__(self, *args, **kwargs):
-        super().__init__("Equal")
-        self.args = args
-        self.kwargs = kwargs
-        self.torch_op = aten.eq.Scalar


 class Mul(Operator):
@@ -257,15 +241,7 @@ def __init__(self, *args, **kwargs):
         super().__init__("Copy")
         self.args = args
         self.kwargs = kwargs
-        self.torch_op = torch.ops.aten.copy.default
-
-
-class Copy_(Operator):
-    def __init__(self, *args, **kwargs):
-        super().__init__("Copy")
-        self.args = args
-        self.kwargs = kwargs
-        self.torch_op = torch.ops.aten.copy_.default
+        self.torch_op = torch.ops.aten.copy


 class LiftFreshCopy(Operator):
@@ -515,6 +491,14 @@ def __init__(self, *args, **kwargs):
         self.torch_op = aten.new_empty_strided.default


+class Equal(Operator):
+    def __init__(self, *args, **kwargs):
+        super().__init__("Equal")
+        self.args = args
+        self.kwargs = kwargs
+        self.torch_op = aten.eq
+
+
 class Expand(Operator):
     def __init__(self, *args, **kwargs):
         super().__init__("Expand")
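On the operator side, the two Equal_* classes collapse into a single Equal operator whose torch_op is the aten.eq overload packet, so both eq.Tensor and eq.Scalar resolve to it; Copy likewise now stores torch.ops.aten.copy instead of the .default overload. A runnable sketch of that pattern with a deliberately simplified Operator base (the real base class in tops_op.py is richer):

# Simplified Operator base for illustration; the real one lives in tops_op.py.
import torch

aten = torch.ops.aten

class Operator:
    def __init__(self, name):
        self.name = name

class Equal(Operator):
    def __init__(self, *args, **kwargs):
        super().__init__("Equal")
        self.args = args
        self.kwargs = kwargs
        self.torch_op = aten.eq  # overload packet covers eq.Tensor and eq.Scalar

eq = Equal(torch.tensor([1, 2, 3]), 2)
print(eq.torch_op(*eq.args))  # tensor([False,  True, False])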

