[dicp][tops] Support some ops for stable-diffusion. (DeepLink-org#467)
* Add sin, cos, erf, split.

1. Generalize MakeTuple in tops_op.
2. Generalize make_const in enflame codegen.
3. Add sin, cos, erf, split for tops.
4. Format Python code in dicp tops.

* Refine code.

* Fix the abs test path.

* Clean up the split code.

* Adjust const op generation.

* Fix the nullptr case in const generation.

---------

Co-authored-by: jinminxi104 <[email protected]>
Co-authored-by: Reinerzhou <[email protected]>
3 people authored and ustclight-sls committed Dec 8, 2023
1 parent a190a80 commit 51978d9
Showing 9 changed files with 250 additions and 31 deletions.
2 changes: 1 addition & 1 deletion dicp/dicp/dynamo_bridge/op_transformer.py
@@ -39,7 +39,7 @@ def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict
for idx, dim in enumerate(fake_tensor.shape):
if isinstance(dim, torch.SymInt):
st = dim.node.str()
if not st in self.sym_in_args:
if st not in self.sym_in_args:
self.sym_in_args[st] = (proxy, idx)
return proxy

61 changes: 37 additions & 24 deletions dicp/dicp/vendor/TopsGraph/codegen/enflame.py
@@ -496,8 +496,7 @@ def gen_main_func(self):
main_body.writeline("")
for i in range(0, len(self.input_args)):
itensor = self.input_args[i].meta['val']
main_body.writeline('arg' + str(i) + ' = ' +
self.gen_random_tensor(itensor))
main_body.writeline('arg' + str(i) + ' = ' + self.gen_random_tensor(itensor))

args = []
for i in range(len(self.input_args)):
@@ -584,16 +583,18 @@ def Abs(op_var, shape, dtype, x, **kwargs_list):
return f"builder::Op {op_var} = builder::Abs({x});"

@staticmethod
def make_const_if_scalar(op_var, value, dtype=torch.float32, count=0):
def make_const(op_var, value, dtype=torch.float32, count=0):
assert isinstance(value, (numbers.Number, list, tuple, str))
src_code = ""
if isinstance(value, numbers.Number):
src_code = f"{cxx_type_set[dtype]} {op_var}_const_value{count} = static_cast<{cxx_type_set[dtype]}>({value});\n"
value = f"{op_var}_const{count}"
const_type = dtype if dtype != torch.float16 else torch.float32
src_code += f"builder::Op {value} = builder::Const(hlir_builder, static_cast<void *>(&{op_var}_const_value{count}), builder::Type({{1}}, {type_set[const_type]}));\n"
if dtype == torch.float16:
src_code += f"{value} = builder::Convert({value}, builder::Type({{1}}, {type_set[dtype]}));\n"
return src_code, value
if isinstance(value, str):
return src_code, value
elif isinstance(value, numbers.Number):
src_code += f"builder::Op {op_var}_const{count} = builder::Const<{cxx_type_set[dtype]}>(hlir_builder, static_cast<{cxx_type_set[dtype]}>({value}), builder::Type({{1}}, {type_set[dtype]}));\n"
elif isinstance(value, (list, tuple)):
src_code += f"std::vector<{cxx_type_set[dtype]}> {op_var}_const_value{count} = {{{', '.join(map(str, value))}}};\n"
src_code += f"builder::Op {op_var}_const{count} = builder::Const<{cxx_type_set[dtype]}>(hlir_builder, {op_var}_const_value{count}, builder::Type({{{len(value)}}}, {type_set[dtype]}));\n"

return src_code, f"{op_var}_const{count}"

@staticmethod
def make_type(op_var, dtype, shape=[1], count=0):
@@ -605,7 +606,7 @@ def make_type(op_var, dtype, shape=[1], count=0):
@staticmethod
# TODO mul + add scaled_y should handle in conversion
def Add(op_var, shape, dtype, x, y, **kwargs_list):
src_code, y = EnflameOverrides.make_const_if_scalar(op_var, y, dtype)
src_code, y = EnflameOverrides.make_const(op_var, y, dtype)
src_code += f"builder::Op {op_var} = builder::Add({x}, {y});"
return src_code

@@ -617,21 +618,21 @@ def Convert(op_var, shape, dtype, x, y, **kwargs_list):

@staticmethod
def Div(op_var, shape, dtype, x, y, **kwargs_list):
src_code, y = EnflameOverrides.make_const_if_scalar(op_var, y, dtype)
src_code, y = EnflameOverrides.make_const(op_var, y, dtype)
src_code += f"builder::Op {op_var} = builder::Div({x}, {y});"
return src_code

@staticmethod
def Sub(op_var, shape, dtype, x, y, **kwargs_list):
src_code_x, x = EnflameOverrides.make_const_if_scalar(op_var, x, dtype)
src_code_y, y = EnflameOverrides.make_const_if_scalar(op_var, y, dtype)
src_code_x, x = EnflameOverrides.make_const(op_var, x, dtype)
src_code_y, y = EnflameOverrides.make_const(op_var, y, dtype)
src_code = src_code_x + src_code_y
src_code += f"builder::Op {op_var} = builder::Sub({x}, {y});"
return src_code

@staticmethod
def Mul(op_var, shape, dtype, x, y, **kwargs_list):
src_code, y = EnflameOverrides.make_const_if_scalar(op_var, y, dtype)
src_code, y = EnflameOverrides.make_const(op_var, y, dtype)
src_code += f"builder::Op {op_var} = builder::Mul({x}, {y});"
return src_code

@@ -663,19 +664,19 @@ def Less(op_var, shape, dtype, x, y, **kwargs_list):

@staticmethod
def Equal(op_var, shape, dtype, x, y, **kwargs_list):
src_code, y = EnflameOverrides.make_const_if_scalar(op_var, y)
src_code, y = EnflameOverrides.make_const(op_var, y)
src_code += f"builder::Op {op_var} = builder::Equal({x}, {y});"
return src_code

@staticmethod
def LessEqual(op_var, shape, dtype, x, y, **kwargs_list):
src_code, y = EnflameOverrides.make_const_if_scalar(op_var, y)
src_code, y = EnflameOverrides.make_const(op_var, y)
src_code += f"builder::Op {op_var} = builder::LessEqual({x}, {y});"
return src_code

@staticmethod
def NotEqual(op_var, shape, dtype, data_type, x, y, **kwargs_list):
src_code, y = EnflameOverrides.make_const_if_scalar(
src_code, y = EnflameOverrides.make_const(
op_var, y, data_type)
src_code += f"builder::Op {op_var} = builder::NotEqual({x}, {y});"
return src_code
@@ -690,7 +691,7 @@ def Neg(op_var, shape, dtype, x, **kwargs_list):

@staticmethod
def Pow(op_var, shape, dtype, x, y, **kwargs_list):
src_code, y = EnflameOverrides.make_const_if_scalar(op_var, y, dtype)
src_code, y = EnflameOverrides.make_const(op_var, y, dtype)
src_code += f"builder::Op {op_var} = builder::Pow({x}, {y});"
return src_code

Expand All @@ -702,10 +703,22 @@ def Exp(op_var, shape, dtype, x, **kwargs_list):
def Sqrt(op_var, shape, dtype, x, **kwargs_list):
return f"builder::Op {op_var} = builder::Sqrt({x});"

@staticmethod
def Sin(op_var, shape, dtype, x, **kwargs_list):
return f"builder::Op {op_var} = builder::Sin({x});"

@staticmethod
def Cos(op_var, shape, dtype, x, **kwargs_list):
return f"builder::Op {op_var} = builder::Cos({x});"

@staticmethod
def Relu(op_var, shape, dtype, x, **kwargs_list):
return f"builder::Op {op_var} = builder::Relu({x});"

@staticmethod
def Erf(op_var, shape, dtype, x, **kwargs_list):
return f"builder::Op {op_var} = builder::Erf({x});"

@staticmethod
def Sigmoid(op_var, shape, dtype, x, **kwargs_list):
return f"builder::Op {op_var} = builder::Sigmoid({x});"
@@ -802,14 +815,14 @@ def Expand(op_var, shape, dtype, x, new_shape, **kwargs_list):

@staticmethod
def Squeeze(op_var, shape, dtype, x, y, **kwargs_list):
src_code, y = EnflameOverrides.make_const_if_scalar(
src_code, y = EnflameOverrides.make_const(
op_var, y, torch.int64)
src_code += f"builder::Op {op_var} = builder::Squeeze({x}, {y});"
return src_code

@staticmethod
def Unsqueeze(op_var, shape, dtype, x, y, **kwargs_list):
src_code, y = EnflameOverrides.make_const_if_scalar(
src_code, y = EnflameOverrides.make_const(
op_var, y, torch.int64)
src_code += f"builder::Op {op_var} = builder::Unsqueeze({x}, {y});"
return src_code
@@ -847,9 +860,9 @@ def SliceInDim(op_var, shape, dtype, x, dim, start, end, step, **kwargs_list):

@staticmethod
def SliceScatter(op_var, shape, dtype, x, y, dim, start, end, step, **kwargs_list):
src_code_index, op_start_index = EnflameOverrides.make_const_if_scalar(
src_code_index, op_start_index = EnflameOverrides.make_const(
op_var, 0, torch.int64, 0)
src_code_index_dim, op_start_index_dim = EnflameOverrides.make_const_if_scalar(
src_code_index_dim, op_start_index_dim = EnflameOverrides.make_const(
op_var, start, torch.int64, 1)
src_code = src_code_index + src_code_index_dim
src_code += f"builder::Op {op_var} = builder::DynamicUpdateSlice({x}, {y}, {{{', '.join([op_start_index_dim if i == dim else op_start_index for i in range(len(shape))])}}});"
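For reference, here is a minimal standalone sketch (plain Python, not the dicp codegen itself) of the dispatch performed by the generalized make_const above: the name of an existing builder::Op passes through untouched, a scalar becomes a one-element builder::Const, and a list or tuple becomes a vector-backed Const of matching length. The CXX_TYPE and HLIR_TYPE maps and their string dtype keys are simplified placeholders for cxx_type_set and type_set; the float16 handling of the old helper is omitted.

import numbers

# Placeholder type maps standing in for cxx_type_set / type_set in enflame.py.
CXX_TYPE = {"float32": "float", "int64": "int64_t"}
HLIR_TYPE = {"float32": "builder::PrimitiveType::F32()",
             "int64": "builder::PrimitiveType::S64()"}


def sketch_make_const(op_var, value, dtype="float32", count=0):
    # Mirrors the three branches of the generalized make_const.
    src_code = ""
    if isinstance(value, str):
        # Already the name of an existing builder::Op: emit nothing.
        return src_code, value
    if isinstance(value, numbers.Number):
        src_code += (f"builder::Op {op_var}_const{count} = "
                     f"builder::Const<{CXX_TYPE[dtype]}>(hlir_builder, "
                     f"static_cast<{CXX_TYPE[dtype]}>({value}), "
                     f"builder::Type({{1}}, {HLIR_TYPE[dtype]}));\n")
    elif isinstance(value, (list, tuple)):
        src_code += (f"std::vector<{CXX_TYPE[dtype]}> {op_var}_const_value{count} = "
                     f"{{{', '.join(map(str, value))}}};\n")
        src_code += (f"builder::Op {op_var}_const{count} = "
                     f"builder::Const<{CXX_TYPE[dtype]}>(hlir_builder, "
                     f"{op_var}_const_value{count}, "
                     f"builder::Type({{{len(value)}}}, {HLIR_TYPE[dtype]}));\n")
    return src_code, f"{op_var}_const{count}"


print(sketch_make_const("op0", 2.0)[0])
print(sketch_make_const("op1", [0, 1], dtype="int64")[0])

Running the sketch prints the C++ snippets that would be spliced into the generated kernel source by the real codegen.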
20 changes: 20 additions & 0 deletions dicp/dicp/vendor/TopsGraph/conversion.py
@@ -148,10 +148,30 @@ def Rsqrt(self, *args, **kwargs):
def Exp(self, *args, **kwargs):
return self.get_proxy(tops_op.Exp, args, kwargs)

@register_conversion(aten.sin)
def Sin(self, *args, **kwargs):
return self.get_proxy(tops_op.Sin, args, kwargs)

@register_conversion(aten.cos)
def Cos(self, *args, **kwargs):
return self.get_proxy(tops_op.Cos, args, kwargs)

@register_conversion(aten.relu)
def Relu(self, *args, **kwargs):
return self.get_proxy(tops_op.Relu, args, kwargs)

@register_conversion(aten.erf)
def Erf(self, *args, **kwargs):
return self.get_proxy(tops_op.Erf, args, kwargs)

@register_conversion(aten.split.Tensor)
def Split(self, a, size, dim=0, **kwargs):
in_shape = a.node.meta["val"].shape
dim = dim % len(in_shape)
sections = (in_shape[dim] + size - 1) // size
splits = (self.get_proxy(tops_op.SliceInDim, (a, dim, i * size, min((i + 1) * size, in_shape[dim]), 1)) for i in range(sections))
return self.get_proxy(tops_op.MakeTuple, tuple(splits))

@register_conversion(aten.sum)
def ReduceSum(self, a, *args, **kwargs):
in_dtype = a.node.meta["val"].dtype
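As a worked example of the Split conversion above: aten.split.Tensor is lowered to ceil(dim_size / size) SliceInDim ops whose ranges tile the split dimension, followed by a MakeTuple. The helper below is only the boundary arithmetic, not the dicp proxy machinery.

def split_ranges(dim_size, size):
    # sections = ceil(dim_size / size), matching (in_shape[dim] + size - 1) // size.
    sections = (dim_size + size - 1) // size
    # Slice i covers [i * size, min((i + 1) * size, dim_size)) with step 1.
    return [(i * size, min((i + 1) * size, dim_size)) for i in range(sections)]


# Splitting a dimension of length 7 with size=3 yields slices [0:3], [3:6], [6:7].
print(split_ranges(7, 3))  # [(0, 3), (3, 6), (6, 7)]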
29 changes: 23 additions & 6 deletions dicp/dicp/vendor/TopsGraph/tops_op.py
@@ -142,13 +142,34 @@ def __init__(self, a):
self.torch_op = aten.exp


class Sin(Operator):
def __init__(self, a):
super().__init__("Sin")
self.a = a
self.torch_op = aten.sin


class Cos(Operator):
def __init__(self, a):
super().__init__("Cos")
self.a = a
self.torch_op = aten.cos


class Relu(Operator):
def __init__(self, a):
super().__init__("Relu")
self.a = a
self.torch_op = aten.relu


class Erf(Operator):
def __init__(self, a):
super().__init__("Erf")
self.a = a
self.torch_op = aten.erf


class ReduceSum(Operator):
def __init__(self, *args, **kwargs):
super().__init__("ReduceSum")
@@ -742,12 +763,8 @@ def __init__(self, a, b):
super().__init__("MakeTuple")
self.torch_op = torch.empty_like

def __call__(self, a, b):
if hasattr(a, 'meta'):
a = a.meta['val']
if hasattr(b, 'meta'):
b = b.meta['val']
return a, b
def __call__(self, *args):
return tuple(arg.meta["val"] if hasattr(arg, "meta") else arg for arg in args)


class XlaGather(Operator):
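A minimal standalone sketch of the generalized MakeTuple.__call__ above: it now accepts any number of arguments instead of exactly two, and unwraps the fake value from anything that carries an FX-style meta dict. FakeNode here is a stand-in for an fx node or proxy, not a dicp class.

class FakeNode:
    # Stand-in for an FX node that carries its fake value in .meta["val"].
    def __init__(self, val):
        self.meta = {"val": val}


def make_tuple_call(*args):
    # Unwrap meta["val"] where present, pass everything else through unchanged.
    return tuple(arg.meta["val"] if hasattr(arg, "meta") else arg for arg in args)


print(make_tuple_call(FakeNode(1), FakeNode(2), 3))  # (1, 2, 3)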
40 changes: 40 additions & 0 deletions dicp/test/op/test_cos.py
@@ -0,0 +1,40 @@
import pytest
from ..common.utils import (
torch,
dynamo,
parse_args,
compile_model,
get_device,
Size,
update_dynamo_config,
)


class OpModule(torch.nn.Module):
def forward(self, a):
res_default = torch.ops.aten.cos.default(a)
return res_default


model = OpModule()
args = parse_args()
compiled_model = compile_model(model, args.backend, args.dynamic)


class TestCos():
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("sizes", [Size((5,), (5, 3)), Size((3, 5), (5, 3)), Size((2, 3, 4), (2, 4))])
@pytest.mark.parametrize("compiled_model", compiled_model)
def test_torch_cos(self, sizes, dtype, compiled_model):
device = get_device()
size = sizes.dynamic if compiled_model.dynamic else sizes.static
input1 = torch.randn(size, dtype=dtype)

dicp_input1 = input1.to(device)

output = model(input1)
dynamo.reset()
update_dynamo_config(compiled_model.dynamic)
dicp_output = compiled_model.model(dicp_input1)

assert torch.allclose(output, dicp_output.cpu(), atol=1e-04, equal_nan=True)
40 changes: 40 additions & 0 deletions dicp/test/op/test_erf.py
@@ -0,0 +1,40 @@
import pytest
from ..common.utils import (
torch,
dynamo,
parse_args,
compile_model,
get_device,
Size,
update_dynamo_config,
)


class OpModule(torch.nn.Module):
def forward(self, a):
res_default = torch.ops.aten.erf.default(a)
return res_default


model = OpModule()
args = parse_args()
compiled_model = compile_model(model, args.backend, args.dynamic)


class TestErf():
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("sizes", [Size((5,), (5, 3)), Size((3, 5), (5, 3)), Size((2, 3, 4), (2, 4))])
@pytest.mark.parametrize("compiled_model", compiled_model)
def test_torch_erf(self, sizes, dtype, compiled_model):
device = get_device()
size = sizes.dynamic if compiled_model.dynamic else sizes.static
input1 = torch.randn(size, dtype=dtype)

dicp_input1 = input1.to(device)

output = model(input1)
dynamo.reset()
update_dynamo_config(compiled_model.dynamic)
dicp_output = compiled_model.model(dicp_input1)

assert torch.allclose(output, dicp_output.cpu(), atol=1e-04, equal_nan=True)
40 changes: 40 additions & 0 deletions dicp/test/op/test_sin.py
@@ -0,0 +1,40 @@
import pytest
from ..common.utils import (
torch,
dynamo,
parse_args,
compile_model,
get_device,
Size,
update_dynamo_config,
)


class OpModule(torch.nn.Module):
def forward(self, a):
res_default = torch.ops.aten.sin.default(a)
return res_default


model = OpModule()
args = parse_args()
compiled_model = compile_model(model, args.backend, args.dynamic)


class TestSin():
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("sizes", [Size((5,), (5, 3)), Size((3, 5), (5, 3)), Size((2, 3, 4), (2, 4))])
@pytest.mark.parametrize("compiled_model", compiled_model)
def test_torch_sin(self, sizes, dtype, compiled_model):
device = get_device()
size = sizes.dynamic if compiled_model.dynamic else sizes.static
input1 = torch.randn(size, dtype=dtype)

dicp_input1 = input1.to(device)

output = model(input1)
dynamo.reset()
update_dynamo_config(compiled_model.dynamic)
dicp_output = compiled_model.model(dicp_input1)

assert torch.allclose(output, dicp_output.cpu(), equal_nan=True)