From 47d0f8acab017b737da8257252942c773562abde Mon Sep 17 00:00:00 2001 From: yaofengchen Date: Mon, 27 Nov 2023 03:03:13 +0000 Subject: [PATCH] Add sin, cos, erf, split. 1. Generalize MakeTuple in tops_op. 2. Generalize make_const in enflame codegen. 3. Add sin, cos, erf, split for tops. 4. Format Python code in dicp tops. --- dicp/dicp/dynamo_bridge/op_transformer.py | 2 +- dicp/dicp/vendor/TopsGraph/codegen/enflame.py | 61 +++++++++++++------ dicp/dicp/vendor/TopsGraph/conversion.py | 23 +++++++ dicp/dicp/vendor/TopsGraph/tops_op.py | 41 +++++++++++-- dicp/test/op/test_cos.py | 40 ++++++++++++ dicp/test/op/test_erf.py | 40 ++++++++++++ dicp/test/op/test_sin.py | 40 ++++++++++++ dicp/test/op/test_split.py | 45 ++++++++++++++ dicp/test/tops_scripts/ops/static.ini | 4 ++ 9 files changed, 270 insertions(+), 26 deletions(-) create mode 100644 dicp/test/op/test_cos.py create mode 100644 dicp/test/op/test_erf.py create mode 100644 dicp/test/op/test_sin.py create mode 100644 dicp/test/op/test_split.py diff --git a/dicp/dicp/dynamo_bridge/op_transformer.py b/dicp/dicp/dynamo_bridge/op_transformer.py index a9166849f..e577be69f 100644 --- a/dicp/dicp/dynamo_bridge/op_transformer.py +++ b/dicp/dicp/dynamo_bridge/op_transformer.py @@ -39,7 +39,7 @@ def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict for idx, dim in enumerate(fake_tensor.shape): if isinstance(dim, torch.SymInt): st = dim.node.str() - if not st in self.sym_in_args: + if st not in self.sym_in_args: self.sym_in_args[st] = (proxy, idx) return proxy diff --git a/dicp/dicp/vendor/TopsGraph/codegen/enflame.py b/dicp/dicp/vendor/TopsGraph/codegen/enflame.py index 627950ff2..f16b1fada 100644 --- a/dicp/dicp/vendor/TopsGraph/codegen/enflame.py +++ b/dicp/dicp/vendor/TopsGraph/codegen/enflame.py @@ -494,8 +494,7 @@ def gen_main_func(self): main_body.writeline("") for i in range(0, len(self.input_args)): itensor = self.input_args[i].meta['val'] - main_body.writeline('arg' + str(i) + ' = ' + - self.gen_random_tensor(itensor)) + main_body.writeline('arg' + str(i) + ' = ' + self.gen_random_tensor(itensor)) args = [] for i in range(len(self.input_args)): @@ -578,13 +577,19 @@ def Abs(op_var, shape, dtype, x, **kwargs_list): return f"builder::Op {op_var} = builder::Abs({x});" @staticmethod - def make_const_if_scalar(op_var, value, dtype=torch.float32, count=0): + def make_const(op_var, value, dtype=torch.float32, count=0): + assert isinstance(value, (numbers.Number, list, tuple, str)) src_code = "" - if isinstance(value, numbers.Number): + if isinstance(value, str): + return src_code, value + elif isinstance(value, numbers.Number): src_code = f"{cxx_type_set[dtype]} {op_var}_const_value{count} = static_cast<{cxx_type_set[dtype]}>({value});\n" - value = f"{op_var}_const{count}" - src_code += f"builder::Type {op_var}_const_type{count}({{1}}, {type_set[dtype]});\n" - src_code += f"builder::Op {value} = builder::Const(hlir_builder, static_cast(&{op_var}_const_value{count}), {op_var}_const_type{count});\n" + src_code += f"builder::Type {op_var}_const_type{count}({{{1}}}, {type_set[dtype]});\n" + elif isinstance(value, (list, tuple)): + src_code = f"std::vector<{cxx_type_set[dtype]}> {op_var}_const_value{count} = {{{', '.join(map(str, value))}}};\n" + src_code += f"builder::Type {op_var}_const_type{count}({{{len(value)}}}, {type_set[dtype]});\n" + value = f"{op_var}_const{count}" + src_code += f"builder::Op {value} = builder::Const(hlir_builder, static_cast(&{op_var}_const_value{count}), {op_var}_const_type{count});\n" return 
src_code, value @staticmethod @@ -597,7 +602,7 @@ def make_type(op_var, dtype, shape=[1], count=0): @staticmethod # TODO mul + add scaled_y should handle in conversion def Add(op_var, shape, dtype, x, y, **kwargs_list): - src_code, y = EnflameOverrides.make_const_if_scalar(op_var, y, dtype) + src_code, y = EnflameOverrides.make_const(op_var, y, dtype) src_code += f"builder::Op {op_var} = builder::Add({x}, {y});" return src_code @@ -609,19 +614,19 @@ def Convert(op_var, shape, dtype, x, y, **kwargs_list): @staticmethod def Div(op_var, shape, dtype, x, y, **kwargs_list): - src_code, y = EnflameOverrides.make_const_if_scalar(op_var, y, dtype) + src_code, y = EnflameOverrides.make_const(op_var, y, dtype) src_code += f"builder::Op {op_var} = builder::Div({x}, {y});" return src_code @staticmethod def Sub(op_var, shape, dtype, x, y, **kwargs_list): - src_code, y = EnflameOverrides.make_const_if_scalar(op_var, y, dtype) + src_code, y = EnflameOverrides.make_const(op_var, y, dtype) src_code += f"builder::Op {op_var} = builder::Sub({x}, {y});" return src_code @staticmethod def Mul(op_var, shape, dtype, x, y, **kwargs_list): - src_code, y = EnflameOverrides.make_const_if_scalar(op_var, y, dtype) + src_code, y = EnflameOverrides.make_const(op_var, y, dtype) src_code += f"builder::Op {op_var} = builder::Mul({x}, {y});" return src_code @@ -653,19 +658,19 @@ def Less(op_var, shape, dtype, x, y, **kwargs_list): @staticmethod def Equal(op_var, shape, dtype, x, y, **kwargs_list): - src_code, y = EnflameOverrides.make_const_if_scalar(op_var, y) + src_code, y = EnflameOverrides.make_const(op_var, y) src_code += f"builder::Op {op_var} = builder::Equal({x}, {y});" return src_code @staticmethod def LessEqual(op_var, shape, dtype, x, y, **kwargs_list): - src_code, y = EnflameOverrides.make_const_if_scalar(op_var, y) + src_code, y = EnflameOverrides.make_const(op_var, y) src_code += f"builder::Op {op_var} = builder::LessEqual({x}, {y});" return src_code @staticmethod def NotEqual(op_var, shape, dtype, data_type, x, y, **kwargs_list): - src_code, y = EnflameOverrides.make_const_if_scalar( + src_code, y = EnflameOverrides.make_const( op_var, y, data_type) src_code += f"builder::Op {op_var} = builder::NotEqual({x}, {y});" return src_code @@ -680,7 +685,7 @@ def Neg(op_var, shape, dtype, x, **kwargs_list): @staticmethod def Pow(op_var, shape, dtype, x, y, **kwargs_list): - src_code, y = EnflameOverrides.make_const_if_scalar(op_var, y) + src_code, y = EnflameOverrides.make_const(op_var, y) src_code += f"builder::Op {op_var} = builder::Pow({x}, {y});" return src_code @@ -692,10 +697,28 @@ def Exp(op_var, shape, dtype, x, **kwargs_list): def Sqrt(op_var, shape, dtype, x, **kwargs_list): return f"builder::Op {op_var} = builder::Sqrt({x});" + @staticmethod + def Sin(op_var, shape, dtype, x, **kwargs_list): + return f"builder::Op {op_var} = builder::Sin({x});" + + @staticmethod + def Cos(op_var, shape, dtype, x, **kwargs_list): + return f"builder::Op {op_var} = builder::Cos({x});" + @staticmethod def Relu(op_var, shape, dtype, x, **kwargs_list): return f"builder::Op {op_var} = builder::Relu({x});" + @staticmethod + def Erf(op_var, shape, dtype, x, **kwargs_list): + return f"builder::Op {op_var} = builder::Erf({x});" + + @staticmethod + def Split(op_var, shape, dtype, x, split, axis=0, num_outputs=0, **kwargs_list): + src_code, split = EnflameOverrides.make_const(op_var, split, torch.int64) + src_code += f"builder::Op {op_var} = builder::GeneralSplit({x}, {split}, {axis});" + return src_code + @staticmethod def 
Sigmoid(op_var, shape, dtype, x, **kwargs_list): return f"builder::Op {op_var} = builder::Sigmoid({x});" @@ -791,14 +814,14 @@ def Expand(op_var, shape, dtype, x, new_shape, **kwargs_list): @staticmethod def Squeeze(op_var, shape, dtype, x, y, **kwargs_list): - src_code, y = EnflameOverrides.make_const_if_scalar( + src_code, y = EnflameOverrides.make_const( op_var, y, torch.int64) src_code += f"builder::Op {op_var} = builder::Squeeze({x}, {y});" return src_code @staticmethod def Unsqueeze(op_var, shape, dtype, x, y, **kwargs_list): - src_code, y = EnflameOverrides.make_const_if_scalar( + src_code, y = EnflameOverrides.make_const( op_var, y, torch.int64) src_code += f"builder::Op {op_var} = builder::Unsqueeze({x}, {y});" return src_code @@ -836,9 +859,9 @@ def SliceInDim(op_var, shape, dtype, x, dim, start, end, step, **kwargs_list): @staticmethod def SliceScatter(op_var, shape, dtype, x, y, dim, start, end, step, **kwargs_list): - src_code_index, op_start_index = EnflameOverrides.make_const_if_scalar( + src_code_index, op_start_index = EnflameOverrides.make_const( op_var, 0, torch.int64, 0) - src_code_index_dim, op_start_index_dim = EnflameOverrides.make_const_if_scalar( + src_code_index_dim, op_start_index_dim = EnflameOverrides.make_const( op_var, start, torch.int64, 1) src_code = src_code_index + src_code_index_dim src_code += f"builder::Op {op_var} = builder::DynamicUpdateSlice({x}, {y}, {{{', '.join([op_start_index_dim if i == dim else op_start_index for i in range(len(shape))])}}});" diff --git a/dicp/dicp/vendor/TopsGraph/conversion.py b/dicp/dicp/vendor/TopsGraph/conversion.py index 30269ddd6..e11951894 100644 --- a/dicp/dicp/vendor/TopsGraph/conversion.py +++ b/dicp/dicp/vendor/TopsGraph/conversion.py @@ -148,10 +148,33 @@ def Rsqrt(self, *args, **kwargs): def Exp(self, *args, **kwargs): return self.get_proxy(tops_op.Exp, args, kwargs) + @register_conversion(aten.sin) + def Sin(self, *args, **kwargs): + return self.get_proxy(tops_op.Sin, args, kwargs) + + @register_conversion(aten.cos) + def Cos(self, *args, **kwargs): + return self.get_proxy(tops_op.Cos, args, kwargs) + @register_conversion(aten.relu) def Relu(self, *args, **kwargs): return self.get_proxy(tops_op.Relu, args, kwargs) + @register_conversion(aten.erf) + def Erf(self, *args, **kwargs): + return self.get_proxy(tops_op.Erf, args, kwargs) + + @register_conversion(aten.split.Tensor) + def Split(self, a, size, dim=0, **kwargs): + in_shape = a.node.meta["val"].shape + dim = dim % len(in_shape) + sections = in_shape[dim] // size if in_shape[dim] % size == 0 else in_shape[dim] // size + 1 + + def get_real_end(end): + return end if end <= in_shape[dim] else in_shape[dim] + splits = (self.get_proxy(tops_op.SliceInDim, (a, dim, i * size, get_real_end((i + 1) * size), 1)) for i in range(sections)) + return self.get_proxy(tops_op.MakeTuple, tuple(splits)) + @register_conversion(aten.sum) def ReduceSum(self, a, *args, **kwargs): in_dtype = a.node.meta["val"].dtype diff --git a/dicp/dicp/vendor/TopsGraph/tops_op.py b/dicp/dicp/vendor/TopsGraph/tops_op.py index d2a9ae0a4..e5858f8e0 100644 --- a/dicp/dicp/vendor/TopsGraph/tops_op.py +++ b/dicp/dicp/vendor/TopsGraph/tops_op.py @@ -142,6 +142,20 @@ def __init__(self, a): self.torch_op = aten.exp +class Sin(Operator): + def __init__(self, a): + super().__init__("Sin") + self.a = a + self.torch_op = aten.sin + + +class Cos(Operator): + def __init__(self, a): + super().__init__("Cos") + self.a = a + self.torch_op = aten.cos + + class Relu(Operator): def __init__(self, a): 
super().__init__("Relu") @@ -149,6 +163,25 @@ def __init__(self, a): self.torch_op = aten.relu +class Erf(Operator): + def __init__(self, a): + super().__init__("Erf") + self.a = a + self.torch_op = aten.erf + + +class Split(Operator): + def __init__(self, *args, **kwargs): + super().__init__("Split") + self.args = args + self.kwargs = kwargs + self.torch_op = aten.split + + def __call__(self, *args, **kwargs): + new_args = args[:3] + return super().__call__(*new_args, **kwargs) + + class ReduceSum(Operator): def __init__(self, *args, **kwargs): super().__init__("ReduceSum") @@ -741,12 +774,8 @@ def __init__(self, a, b): super().__init__("MakeTuple") self.torch_op = torch.empty_like - def __call__(self, a, b): - if hasattr(a, 'meta'): - a = a.meta['val'] - if hasattr(b, 'meta'): - b = b.meta['val'] - return a, b + def __call__(self, *args): + return (arg.meta["val"] if hasattr(arg, "meta") else arg for arg in args) class XlaGather(Operator): diff --git a/dicp/test/op/test_cos.py b/dicp/test/op/test_cos.py new file mode 100644 index 000000000..72a36e53f --- /dev/null +++ b/dicp/test/op/test_cos.py @@ -0,0 +1,40 @@ +import pytest +from common.utils import ( + torch, + dynamo, + parse_args, + compile_model, + get_device, + Size, + update_dynamo_config, +) + + +class OpModule(torch.nn.Module): + def forward(self, a): + res_default = torch.ops.aten.cos.default(a) + return res_default + + +model = OpModule() +args = parse_args() +compiled_model = compile_model(model, args.backend, args.dynamic) + + +class TestCos(): + @pytest.mark.parametrize("dtype", [torch.float32]) + @pytest.mark.parametrize("sizes", [Size((5,), (5, 3)), Size((3, 5), (5, 3)), Size((2, 3, 4), (2, 4))]) + @pytest.mark.parametrize("compiled_model", compiled_model) + def test_torch_cos(self, sizes, dtype, compiled_model): + device = get_device() + size = sizes.dynamic if compiled_model.dynamic else sizes.static + input1 = torch.randn(size, dtype=dtype) + + dicp_input1 = input1.to(device) + + output = model(input1) + dynamo.reset() + update_dynamo_config(compiled_model.dynamic) + dicp_output = compiled_model.model(dicp_input1) + + assert torch.allclose(output, dicp_output.cpu(), atol=1e-04, equal_nan=True) diff --git a/dicp/test/op/test_erf.py b/dicp/test/op/test_erf.py new file mode 100644 index 000000000..19ebb06cf --- /dev/null +++ b/dicp/test/op/test_erf.py @@ -0,0 +1,40 @@ +import pytest +from common.utils import ( + torch, + dynamo, + parse_args, + compile_model, + get_device, + Size, + update_dynamo_config, +) + + +class OpModule(torch.nn.Module): + def forward(self, a): + res_default = torch.ops.aten.erf.default(a) + return res_default + + +model = OpModule() +args = parse_args() +compiled_model = compile_model(model, args.backend, args.dynamic) + + +class TestErf(): + @pytest.mark.parametrize("dtype", [torch.float32]) + @pytest.mark.parametrize("sizes", [Size((5,), (5, 3)), Size((3, 5), (5, 3)), Size((2, 3, 4), (2, 4))]) + @pytest.mark.parametrize("compiled_model", compiled_model) + def test_torch_erf(self, sizes, dtype, compiled_model): + device = get_device() + size = sizes.dynamic if compiled_model.dynamic else sizes.static + input1 = torch.randn(size, dtype=dtype) + + dicp_input1 = input1.to(device) + + output = model(input1) + dynamo.reset() + update_dynamo_config(compiled_model.dynamic) + dicp_output = compiled_model.model(dicp_input1) + + assert torch.allclose(output, dicp_output.cpu(), atol=1e-04, equal_nan=True) diff --git a/dicp/test/op/test_sin.py b/dicp/test/op/test_sin.py new file mode 100644 index 
000000000..a62d7315a --- /dev/null +++ b/dicp/test/op/test_sin.py @@ -0,0 +1,40 @@ +import pytest +from common.utils import ( + torch, + dynamo, + parse_args, + compile_model, + get_device, + Size, + update_dynamo_config, +) + + +class OpModule(torch.nn.Module): + def forward(self, a): + res_default = torch.ops.aten.sin.default(a) + return res_default + + +model = OpModule() +args = parse_args() +compiled_model = compile_model(model, args.backend, args.dynamic) + + +class TestSin(): + @pytest.mark.parametrize("dtype", [torch.float32]) + @pytest.mark.parametrize("sizes", [Size((5,), (5, 3)), Size((3, 5), (5, 3)), Size((2, 3, 4), (2, 4))]) + @pytest.mark.parametrize("compiled_model", compiled_model) + def test_torch_sin(self, sizes, dtype, compiled_model): + device = get_device() + size = sizes.dynamic if compiled_model.dynamic else sizes.static + input1 = torch.randn(size, dtype=dtype) + + dicp_input1 = input1.to(device) + + output = model(input1) + dynamo.reset() + update_dynamo_config(compiled_model.dynamic) + dicp_output = compiled_model.model(dicp_input1) + + assert torch.allclose(output, dicp_output.cpu(), equal_nan=True) diff --git a/dicp/test/op/test_split.py b/dicp/test/op/test_split.py new file mode 100644 index 000000000..841f6bcf7 --- /dev/null +++ b/dicp/test/op/test_split.py @@ -0,0 +1,45 @@ +import pytest +from common.utils import ( + torch, + dynamo, + parse_args, + compile_model, + get_device, + Size, + update_dynamo_config, +) + + +class OpModule(torch.nn.Module): + def forward(self, x, split_size_or_sections, dim): + res_Tensor = torch.ops.aten.split.Tensor(x, split_size_or_sections, dim) + for i in range(len(res_Tensor)): + res_Tensor[i] = res_Tensor[i] + 1.0 + return res_Tensor + + +model = OpModule() +args = parse_args() +compiled_model = compile_model(model, args.backend, args.dynamic) + + +class TestSplit(): + @pytest.mark.parametrize("dtype", [torch.float32]) + @pytest.mark.parametrize("sizes", [Size((5, 3), (5, 3)), Size((3, 5), (5, 3)), Size((2, 3, 4), (2, 4))]) + @pytest.mark.parametrize("split_size_or_sections", [1, 2]) + @pytest.mark.parametrize("dim", [0, -1]) + @pytest.mark.parametrize("compiled_model", compiled_model) + def test_torch_split(self, sizes, split_size_or_sections, dim, dtype, compiled_model): + device = get_device() + size = sizes.dynamic if compiled_model.dynamic else sizes.static + input1 = torch.randn(size, dtype=dtype) + + dicp_input1 = input1.to(device) + + output = model(input1, split_size_or_sections, dim) + dynamo.reset() + update_dynamo_config(compiled_model.dynamic) + dicp_output = compiled_model.model(dicp_input1, split_size_or_sections, dim) + + for i, item in enumerate(output): + assert torch.allclose(item, dicp_output[i].cpu(), equal_nan=True) diff --git a/dicp/test/tops_scripts/ops/static.ini b/dicp/test/tops_scripts/ops/static.ini index 8a45c1c3b..69d12d2a9 100644 --- a/dicp/test/tops_scripts/ops/static.ini +++ b/dicp/test/tops_scripts/ops/static.ini @@ -20,10 +20,12 @@ python_files = test__adpative_avg_pool2d_backward.py test_convolution.py test_copy_.py test_copy.py + test_cos.py test_div.py test_embedding.py test_empty_like.py test_eq.py + test_erf.py test_exp.py test_expand.py test_fill.py @@ -63,9 +65,11 @@ python_files = test__adpative_avg_pool2d_backward.py test_scatter.py test_select.py test_sigmoid.py + test_sin.py test_slice_scatter.py test_slice.py test_sqrt.py + test_split.py test_square.py test_squeeze.py test_sub.py
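
Note: a minimal standalone sketch (plain Python, no dicp imports; the helper name below is hypothetical) of the index arithmetic that the aten.split.Tensor conversion above relies on when it lowers a split into SliceInDim ops followed by MakeTuple:

# Sketch of the slicing arithmetic used by the new Split conversion.
# Splitting dimension `dim` of `in_shape` into chunks of `size` yields
# ceil(in_shape[dim] / size) sections; the last one may be shorter.
def split_ranges(in_shape, size, dim=0):
    dim = dim % len(in_shape)
    length = in_shape[dim]
    sections = length // size if length % size == 0 else length // size + 1
    # Each (start, end) pair is what one SliceInDim proxy would receive.
    return [(i * size, min((i + 1) * size, length)) for i in range(sections)]

# Example: a (5, 3) tensor split into chunks of 2 along dim 0
# -> [(0, 2), (2, 4), (4, 5)], matching torch.split(x, 2, dim=0).
print(split_ranges((5, 3), 2, dim=0))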
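
For reference, the newly added ops can be exercised end to end through torch.compile once the patch is applied; a minimal usage sketch, assuming the TopsGraph backend is registered under the name "topsgraph" as in the existing dicp tests:

import torch


def fn(x):
    a, b = torch.split(torch.sin(x) + torch.cos(x), 3, dim=-1)
    return torch.erf(a) * b


# Assumes a Tops device and the dicp "topsgraph" backend are available.
compiled = torch.compile(fn, backend="topsgraph")
x = torch.randn(2, 6)
print(compiled(x))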