From 51978d9b38d9020b8820062d55bcf6eb6ed3d811 Mon Sep 17 00:00:00 2001 From: yaofengchen <67218893+yao-fengchen@users.noreply.github.com> Date: Wed, 6 Dec 2023 14:23:56 +0800 Subject: [PATCH] [dicp][tops] Support some ops for stable-diffusion. (#467) * Add sin, cos, erf, split. 1. Generalize MakeTuple in tops_op. 2. Generalize make_const in enflame codegen. 3. Add sin, cos, erf, split for tops. 4. Format Python code in dicp tops. * refine code * fix abs test path * clean up code of split. * adjust const op generation. * fix nullptr case in const generation. --------- Co-authored-by: jinminxi104 Co-authored-by: Reinerzhou <1768552509@qq.com> --- dicp/dicp/dynamo_bridge/op_transformer.py | 2 +- dicp/dicp/vendor/TopsGraph/codegen/enflame.py | 61 +++++++++++-------- dicp/dicp/vendor/TopsGraph/conversion.py | 20 ++++++ dicp/dicp/vendor/TopsGraph/tops_op.py | 29 +++++++-- dicp/test/op/test_cos.py | 40 ++++++++++++ dicp/test/op/test_erf.py | 40 ++++++++++++ dicp/test/op/test_sin.py | 40 ++++++++++++ dicp/test/op/test_split.py | 45 ++++++++++++++ dicp/test/tops_scripts/ops/static.ini | 4 ++ 9 files changed, 250 insertions(+), 31 deletions(-) create mode 100644 dicp/test/op/test_cos.py create mode 100644 dicp/test/op/test_erf.py create mode 100644 dicp/test/op/test_sin.py create mode 100644 dicp/test/op/test_split.py diff --git a/dicp/dicp/dynamo_bridge/op_transformer.py b/dicp/dicp/dynamo_bridge/op_transformer.py index a9166849f..e577be69f 100644 --- a/dicp/dicp/dynamo_bridge/op_transformer.py +++ b/dicp/dicp/dynamo_bridge/op_transformer.py @@ -39,7 +39,7 @@ def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict for idx, dim in enumerate(fake_tensor.shape): if isinstance(dim, torch.SymInt): st = dim.node.str() - if not st in self.sym_in_args: + if st not in self.sym_in_args: self.sym_in_args[st] = (proxy, idx) return proxy diff --git a/dicp/dicp/vendor/TopsGraph/codegen/enflame.py b/dicp/dicp/vendor/TopsGraph/codegen/enflame.py index 2045f20da..6682a883e 100644 --- a/dicp/dicp/vendor/TopsGraph/codegen/enflame.py +++ b/dicp/dicp/vendor/TopsGraph/codegen/enflame.py @@ -496,8 +496,7 @@ def gen_main_func(self): main_body.writeline("") for i in range(0, len(self.input_args)): itensor = self.input_args[i].meta['val'] - main_body.writeline('arg' + str(i) + ' = ' + - self.gen_random_tensor(itensor)) + main_body.writeline('arg' + str(i) + ' = ' + self.gen_random_tensor(itensor)) args = [] for i in range(len(self.input_args)): @@ -584,16 +583,18 @@ def Abs(op_var, shape, dtype, x, **kwargs_list): return f"builder::Op {op_var} = builder::Abs({x});" @staticmethod - def make_const_if_scalar(op_var, value, dtype=torch.float32, count=0): + def make_const(op_var, value, dtype=torch.float32, count=0): + assert isinstance(value, (numbers.Number, list, tuple, str)) src_code = "" - if isinstance(value, numbers.Number): - src_code = f"{cxx_type_set[dtype]} {op_var}_const_value{count} = static_cast<{cxx_type_set[dtype]}>({value});\n" - value = f"{op_var}_const{count}" - const_type = dtype if dtype != torch.float16 else torch.float32 - src_code += f"builder::Op {value} = builder::Const(hlir_builder, static_cast(&{op_var}_const_value{count}), builder::Type({{1}}, {type_set[const_type]}));\n" - if dtype == torch.float16: - src_code += f"{value} = builder::Convert({value}, builder::Type({{1}}, {type_set[dtype]}));\n" - return src_code, value + if isinstance(value, str): + return src_code, value + elif isinstance(value, numbers.Number): + src_code += f"builder::Op {op_var}_const{count} = 
builder::Const<{cxx_type_set[dtype]}>(hlir_builder, static_cast<{cxx_type_set[dtype]}>({value}), builder::Type({{1}}, {type_set[dtype]}));\n" + elif isinstance(value, (list, tuple)): + src_code += f"std::vector<{cxx_type_set[dtype]}> {op_var}_const_value{count} = {{{', '.join(map(str, value))}}};\n" + src_code += f"builder::Op {op_var}_const{count} = builder::Const<{cxx_type_set[dtype]}>(hlir_builder, {op_var}_const_value{count}, builder::Type({{{len(value)}}}, {type_set[dtype]}));\n" + + return src_code, f"{op_var}_const{count}" @staticmethod def make_type(op_var, dtype, shape=[1], count=0): @@ -605,7 +606,7 @@ def make_type(op_var, dtype, shape=[1], count=0): @staticmethod # TODO mul + add scaled_y should handle in conversion def Add(op_var, shape, dtype, x, y, **kwargs_list): - src_code, y = EnflameOverrides.make_const_if_scalar(op_var, y, dtype) + src_code, y = EnflameOverrides.make_const(op_var, y, dtype) src_code += f"builder::Op {op_var} = builder::Add({x}, {y});" return src_code @@ -617,21 +618,21 @@ def Convert(op_var, shape, dtype, x, y, **kwargs_list): @staticmethod def Div(op_var, shape, dtype, x, y, **kwargs_list): - src_code, y = EnflameOverrides.make_const_if_scalar(op_var, y, dtype) + src_code, y = EnflameOverrides.make_const(op_var, y, dtype) src_code += f"builder::Op {op_var} = builder::Div({x}, {y});" return src_code @staticmethod def Sub(op_var, shape, dtype, x, y, **kwargs_list): - src_code_x, x = EnflameOverrides.make_const_if_scalar(op_var, x, dtype) - src_code_y, y = EnflameOverrides.make_const_if_scalar(op_var, y, dtype) + src_code_x, x = EnflameOverrides.make_const(op_var, x, dtype) + src_code_y, y = EnflameOverrides.make_const(op_var, y, dtype) src_code = src_code_x + src_code_y src_code += f"builder::Op {op_var} = builder::Sub({x}, {y});" return src_code @staticmethod def Mul(op_var, shape, dtype, x, y, **kwargs_list): - src_code, y = EnflameOverrides.make_const_if_scalar(op_var, y, dtype) + src_code, y = EnflameOverrides.make_const(op_var, y, dtype) src_code += f"builder::Op {op_var} = builder::Mul({x}, {y});" return src_code @@ -663,19 +664,19 @@ def Less(op_var, shape, dtype, x, y, **kwargs_list): @staticmethod def Equal(op_var, shape, dtype, x, y, **kwargs_list): - src_code, y = EnflameOverrides.make_const_if_scalar(op_var, y) + src_code, y = EnflameOverrides.make_const(op_var, y) src_code += f"builder::Op {op_var} = builder::Equal({x}, {y});" return src_code @staticmethod def LessEqual(op_var, shape, dtype, x, y, **kwargs_list): - src_code, y = EnflameOverrides.make_const_if_scalar(op_var, y) + src_code, y = EnflameOverrides.make_const(op_var, y) src_code += f"builder::Op {op_var} = builder::LessEqual({x}, {y});" return src_code @staticmethod def NotEqual(op_var, shape, dtype, data_type, x, y, **kwargs_list): - src_code, y = EnflameOverrides.make_const_if_scalar( + src_code, y = EnflameOverrides.make_const( op_var, y, data_type) src_code += f"builder::Op {op_var} = builder::NotEqual({x}, {y});" return src_code @@ -690,7 +691,7 @@ def Neg(op_var, shape, dtype, x, **kwargs_list): @staticmethod def Pow(op_var, shape, dtype, x, y, **kwargs_list): - src_code, y = EnflameOverrides.make_const_if_scalar(op_var, y, dtype) + src_code, y = EnflameOverrides.make_const(op_var, y, dtype) src_code += f"builder::Op {op_var} = builder::Pow({x}, {y});" return src_code @@ -702,10 +703,22 @@ def Exp(op_var, shape, dtype, x, **kwargs_list): def Sqrt(op_var, shape, dtype, x, **kwargs_list): return f"builder::Op {op_var} = builder::Sqrt({x});" + @staticmethod + def Sin(op_var, 
shape, dtype, x, **kwargs_list): + return f"builder::Op {op_var} = builder::Sin({x});" + + @staticmethod + def Cos(op_var, shape, dtype, x, **kwargs_list): + return f"builder::Op {op_var} = builder::Cos({x});" + @staticmethod def Relu(op_var, shape, dtype, x, **kwargs_list): return f"builder::Op {op_var} = builder::Relu({x});" + @staticmethod + def Erf(op_var, shape, dtype, x, **kwargs_list): + return f"builder::Op {op_var} = builder::Erf({x});" + @staticmethod def Sigmoid(op_var, shape, dtype, x, **kwargs_list): return f"builder::Op {op_var} = builder::Sigmoid({x});" @@ -802,14 +815,14 @@ def Expand(op_var, shape, dtype, x, new_shape, **kwargs_list): @staticmethod def Squeeze(op_var, shape, dtype, x, y, **kwargs_list): - src_code, y = EnflameOverrides.make_const_if_scalar( + src_code, y = EnflameOverrides.make_const( op_var, y, torch.int64) src_code += f"builder::Op {op_var} = builder::Squeeze({x}, {y});" return src_code @staticmethod def Unsqueeze(op_var, shape, dtype, x, y, **kwargs_list): - src_code, y = EnflameOverrides.make_const_if_scalar( + src_code, y = EnflameOverrides.make_const( op_var, y, torch.int64) src_code += f"builder::Op {op_var} = builder::Unsqueeze({x}, {y});" return src_code @@ -847,9 +860,9 @@ def SliceInDim(op_var, shape, dtype, x, dim, start, end, step, **kwargs_list): @staticmethod def SliceScatter(op_var, shape, dtype, x, y, dim, start, end, step, **kwargs_list): - src_code_index, op_start_index = EnflameOverrides.make_const_if_scalar( + src_code_index, op_start_index = EnflameOverrides.make_const( op_var, 0, torch.int64, 0) - src_code_index_dim, op_start_index_dim = EnflameOverrides.make_const_if_scalar( + src_code_index_dim, op_start_index_dim = EnflameOverrides.make_const( op_var, start, torch.int64, 1) src_code = src_code_index + src_code_index_dim src_code += f"builder::Op {op_var} = builder::DynamicUpdateSlice({x}, {y}, {{{', '.join([op_start_index_dim if i == dim else op_start_index for i in range(len(shape))])}}});" diff --git a/dicp/dicp/vendor/TopsGraph/conversion.py b/dicp/dicp/vendor/TopsGraph/conversion.py index 0fdfa746c..fa9ccbfbc 100644 --- a/dicp/dicp/vendor/TopsGraph/conversion.py +++ b/dicp/dicp/vendor/TopsGraph/conversion.py @@ -148,10 +148,30 @@ def Rsqrt(self, *args, **kwargs): def Exp(self, *args, **kwargs): return self.get_proxy(tops_op.Exp, args, kwargs) + @register_conversion(aten.sin) + def Sin(self, *args, **kwargs): + return self.get_proxy(tops_op.Sin, args, kwargs) + + @register_conversion(aten.cos) + def Cos(self, *args, **kwargs): + return self.get_proxy(tops_op.Cos, args, kwargs) + @register_conversion(aten.relu) def Relu(self, *args, **kwargs): return self.get_proxy(tops_op.Relu, args, kwargs) + @register_conversion(aten.erf) + def Erf(self, *args, **kwargs): + return self.get_proxy(tops_op.Erf, args, kwargs) + + @register_conversion(aten.split.Tensor) + def Split(self, a, size, dim=0, **kwargs): + in_shape = a.node.meta["val"].shape + dim = dim % len(in_shape) + sections = (in_shape[dim] + size - 1) // size + splits = (self.get_proxy(tops_op.SliceInDim, (a, dim, i * size, min((i + 1) * size, in_shape[dim]), 1)) for i in range(sections)) + return self.get_proxy(tops_op.MakeTuple, tuple(splits)) + @register_conversion(aten.sum) def ReduceSum(self, a, *args, **kwargs): in_dtype = a.node.meta["val"].dtype diff --git a/dicp/dicp/vendor/TopsGraph/tops_op.py b/dicp/dicp/vendor/TopsGraph/tops_op.py index 7199bbf24..dc89a407f 100644 --- a/dicp/dicp/vendor/TopsGraph/tops_op.py +++ b/dicp/dicp/vendor/TopsGraph/tops_op.py @@ -142,6 +142,20 
@@ def __init__(self, a): self.torch_op = aten.exp +class Sin(Operator): + def __init__(self, a): + super().__init__("Sin") + self.a = a + self.torch_op = aten.sin + + +class Cos(Operator): + def __init__(self, a): + super().__init__("Cos") + self.a = a + self.torch_op = aten.cos + + class Relu(Operator): def __init__(self, a): super().__init__("Relu") @@ -149,6 +163,13 @@ def __init__(self, a): self.torch_op = aten.relu +class Erf(Operator): + def __init__(self, a): + super().__init__("Erf") + self.a = a + self.torch_op = aten.erf + + class ReduceSum(Operator): def __init__(self, *args, **kwargs): super().__init__("ReduceSum") @@ -742,12 +763,8 @@ def __init__(self, a, b): super().__init__("MakeTuple") self.torch_op = torch.empty_like - def __call__(self, a, b): - if hasattr(a, 'meta'): - a = a.meta['val'] - if hasattr(b, 'meta'): - b = b.meta['val'] - return a, b + def __call__(self, *args): + return (arg.meta["val"] if hasattr(arg, "meta") else arg for arg in args) class XlaGather(Operator): diff --git a/dicp/test/op/test_cos.py b/dicp/test/op/test_cos.py new file mode 100644 index 000000000..cfbc4d98b --- /dev/null +++ b/dicp/test/op/test_cos.py @@ -0,0 +1,40 @@ +import pytest +from ..common.utils import ( + torch, + dynamo, + parse_args, + compile_model, + get_device, + Size, + update_dynamo_config, +) + + +class OpModule(torch.nn.Module): + def forward(self, a): + res_default = torch.ops.aten.cos.default(a) + return res_default + + +model = OpModule() +args = parse_args() +compiled_model = compile_model(model, args.backend, args.dynamic) + + +class TestCos(): + @pytest.mark.parametrize("dtype", [torch.float32]) + @pytest.mark.parametrize("sizes", [Size((5,), (5, 3)), Size((3, 5), (5, 3)), Size((2, 3, 4), (2, 4))]) + @pytest.mark.parametrize("compiled_model", compiled_model) + def test_torch_cos(self, sizes, dtype, compiled_model): + device = get_device() + size = sizes.dynamic if compiled_model.dynamic else sizes.static + input1 = torch.randn(size, dtype=dtype) + + dicp_input1 = input1.to(device) + + output = model(input1) + dynamo.reset() + update_dynamo_config(compiled_model.dynamic) + dicp_output = compiled_model.model(dicp_input1) + + assert torch.allclose(output, dicp_output.cpu(), atol=1e-04, equal_nan=True) diff --git a/dicp/test/op/test_erf.py b/dicp/test/op/test_erf.py new file mode 100644 index 000000000..f649309ad --- /dev/null +++ b/dicp/test/op/test_erf.py @@ -0,0 +1,40 @@ +import pytest +from ..common.utils import ( + torch, + dynamo, + parse_args, + compile_model, + get_device, + Size, + update_dynamo_config, +) + + +class OpModule(torch.nn.Module): + def forward(self, a): + res_default = torch.ops.aten.erf.default(a) + return res_default + + +model = OpModule() +args = parse_args() +compiled_model = compile_model(model, args.backend, args.dynamic) + + +class TestErf(): + @pytest.mark.parametrize("dtype", [torch.float32]) + @pytest.mark.parametrize("sizes", [Size((5,), (5, 3)), Size((3, 5), (5, 3)), Size((2, 3, 4), (2, 4))]) + @pytest.mark.parametrize("compiled_model", compiled_model) + def test_torch_erf(self, sizes, dtype, compiled_model): + device = get_device() + size = sizes.dynamic if compiled_model.dynamic else sizes.static + input1 = torch.randn(size, dtype=dtype) + + dicp_input1 = input1.to(device) + + output = model(input1) + dynamo.reset() + update_dynamo_config(compiled_model.dynamic) + dicp_output = compiled_model.model(dicp_input1) + + assert torch.allclose(output, dicp_output.cpu(), atol=1e-04, equal_nan=True) diff --git a/dicp/test/op/test_sin.py 
b/dicp/test/op/test_sin.py new file mode 100644 index 000000000..e37cfaf7c --- /dev/null +++ b/dicp/test/op/test_sin.py @@ -0,0 +1,40 @@ +import pytest +from ..common.utils import ( + torch, + dynamo, + parse_args, + compile_model, + get_device, + Size, + update_dynamo_config, +) + + +class OpModule(torch.nn.Module): + def forward(self, a): + res_default = torch.ops.aten.sin.default(a) + return res_default + + +model = OpModule() +args = parse_args() +compiled_model = compile_model(model, args.backend, args.dynamic) + + +class TestSin(): + @pytest.mark.parametrize("dtype", [torch.float32]) + @pytest.mark.parametrize("sizes", [Size((5,), (5, 3)), Size((3, 5), (5, 3)), Size((2, 3, 4), (2, 4))]) + @pytest.mark.parametrize("compiled_model", compiled_model) + def test_torch_sin(self, sizes, dtype, compiled_model): + device = get_device() + size = sizes.dynamic if compiled_model.dynamic else sizes.static + input1 = torch.randn(size, dtype=dtype) + + dicp_input1 = input1.to(device) + + output = model(input1) + dynamo.reset() + update_dynamo_config(compiled_model.dynamic) + dicp_output = compiled_model.model(dicp_input1) + + assert torch.allclose(output, dicp_output.cpu(), equal_nan=True) diff --git a/dicp/test/op/test_split.py b/dicp/test/op/test_split.py new file mode 100644 index 000000000..bbb737fcf --- /dev/null +++ b/dicp/test/op/test_split.py @@ -0,0 +1,45 @@ +import pytest +from ..common.utils import ( + torch, + dynamo, + parse_args, + compile_model, + get_device, + Size, + update_dynamo_config, +) + + +class OpModule(torch.nn.Module): + def forward(self, x, split_size_or_sections, dim): + res_Tensor = torch.ops.aten.split.Tensor(x, split_size_or_sections, dim) + for i in range(len(res_Tensor)): + res_Tensor[i] = res_Tensor[i] + 1.0 + return res_Tensor + + +model = OpModule() +args = parse_args() +compiled_model = compile_model(model, args.backend, args.dynamic) + + +class TestSplit(): + @pytest.mark.parametrize("dtype", [torch.float32]) + @pytest.mark.parametrize("sizes", [Size((5, 3), (5, 3)), Size((3, 5), (5, 3)), Size((2, 3, 4), (2, 4))]) + @pytest.mark.parametrize("split_size_or_sections", [1, 2]) + @pytest.mark.parametrize("dim", [0, -1]) + @pytest.mark.parametrize("compiled_model", compiled_model) + def test_torch_split(self, sizes, split_size_or_sections, dim, dtype, compiled_model): + device = get_device() + size = sizes.dynamic if compiled_model.dynamic else sizes.static + input1 = torch.randn(size, dtype=dtype) + + dicp_input1 = input1.to(device) + + output = model(input1, split_size_or_sections, dim) + dynamo.reset() + update_dynamo_config(compiled_model.dynamic) + dicp_output = compiled_model.model(dicp_input1, split_size_or_sections, dim) + + for i, item in enumerate(output): + assert torch.allclose(item, dicp_output[i].cpu(), equal_nan=True) diff --git a/dicp/test/tops_scripts/ops/static.ini b/dicp/test/tops_scripts/ops/static.ini index 8a45c1c3b..69d12d2a9 100644 --- a/dicp/test/tops_scripts/ops/static.ini +++ b/dicp/test/tops_scripts/ops/static.ini @@ -20,10 +20,12 @@ python_files = test__adpative_avg_pool2d_backward.py test_convolution.py test_copy_.py test_copy.py + test_cos.py test_div.py test_embedding.py test_empty_like.py test_eq.py + test_erf.py test_exp.py test_expand.py test_fill.py @@ -63,9 +65,11 @@ python_files = test__adpative_avg_pool2d_backward.py test_scatter.py test_select.py test_sigmoid.py + test_sin.py test_slice_scatter.py test_slice.py test_sqrt.py + test_split.py test_square.py test_squeeze.py test_sub.py
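
Note: the interesting piece of this patch is how the new aten.split.Tensor conversion is lowered: rather than a dedicated Split op, conversion.py emits one SliceInDim per chunk and wraps the results in MakeTuple. Below is a minimal, hedged sketch of that index arithmetic only. The helper name split_to_slices is hypothetical and not part of the patch; it simply mirrors the section/start/end computation added in TopsGraph/conversion.py.

    # Illustrative sketch (assumption: a standalone helper mirroring the Split
    # conversion's arithmetic; the real lowering emits tops_op.SliceInDim proxies).
    def split_to_slices(in_shape, size, dim=0):
        """Return the (start, end) ranges that Split decomposes into along dim."""
        dim = dim % len(in_shape)                      # normalize negative dims
        sections = (in_shape[dim] + size - 1) // size  # ceil-divide into chunks
        return [(i * size, min((i + 1) * size, in_shape[dim]))
                for i in range(sections)]

    # Example: splitting dim 0 of a (5, 3) tensor into chunks of 2
    # produces three slices, the last one shorter than the chunk size.
    assert split_to_slices((5, 3), 2, 0) == [(0, 2), (2, 4), (4, 5)]

Each (start, end) pair above corresponds to one SliceInDim(a, dim, start, end, 1) proxy in the actual conversion, and the generalized MakeTuple then packages the slices so downstream consumers see the same tuple shape that aten.split.Tensor would return.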