Skip to content

Commit

Permalink
Add sin, cos, erf, split.
Browse files Browse the repository at this point in the history
1. Generalize MakeTuple in tops_op.
2. Generalize make_const in enflame codegen.
3. Add sin, cos, erf, split for tops.
4. Format Python code in dicp tops.
  • Loading branch information
yao-fengchen committed Nov 27, 2023
1 parent 870e796 commit 47d0f8a
Show file tree
Hide file tree
Showing 9 changed files with 270 additions and 26 deletions.
2 changes: 1 addition & 1 deletion dicp/dicp/dynamo_bridge/op_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict
for idx, dim in enumerate(fake_tensor.shape):
if isinstance(dim, torch.SymInt):
st = dim.node.str()
if not st in self.sym_in_args:
if st not in self.sym_in_args:
self.sym_in_args[st] = (proxy, idx)
return proxy

Expand Down
61 changes: 42 additions & 19 deletions dicp/dicp/vendor/TopsGraph/codegen/enflame.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,8 +494,7 @@ def gen_main_func(self):
main_body.writeline("")
for i in range(0, len(self.input_args)):
itensor = self.input_args[i].meta['val']
main_body.writeline('arg' + str(i) + ' = ' +
self.gen_random_tensor(itensor))
main_body.writeline('arg' + str(i) + ' = ' + self.gen_random_tensor(itensor))

args = []
for i in range(len(self.input_args)):
Expand Down Expand Up @@ -578,13 +577,19 @@ def Abs(op_var, shape, dtype, x, **kwargs_list):
return f"builder::Op {op_var} = builder::Abs({x});"

@staticmethod
def make_const_if_scalar(op_var, value, dtype=torch.float32, count=0):
def make_const(op_var, value, dtype=torch.float32, count=0):
assert isinstance(value, (numbers.Number, list, tuple, str))
src_code = ""
if isinstance(value, numbers.Number):
if isinstance(value, str):
return src_code, value
elif isinstance(value, numbers.Number):
src_code = f"{cxx_type_set[dtype]} {op_var}_const_value{count} = static_cast<{cxx_type_set[dtype]}>({value});\n"
value = f"{op_var}_const{count}"
src_code += f"builder::Type {op_var}_const_type{count}({{1}}, {type_set[dtype]});\n"
src_code += f"builder::Op {value} = builder::Const(hlir_builder, static_cast<void *>(&{op_var}_const_value{count}), {op_var}_const_type{count});\n"
src_code += f"builder::Type {op_var}_const_type{count}({{{1}}}, {type_set[dtype]});\n"
elif isinstance(value, (list, tuple)):
src_code = f"std::vector<{cxx_type_set[dtype]}> {op_var}_const_value{count} = {{{', '.join(map(str, value))}}};\n"
src_code += f"builder::Type {op_var}_const_type{count}({{{len(value)}}}, {type_set[dtype]});\n"
value = f"{op_var}_const{count}"
src_code += f"builder::Op {value} = builder::Const(hlir_builder, static_cast<void *>(&{op_var}_const_value{count}), {op_var}_const_type{count});\n"
return src_code, value

@staticmethod
Expand All @@ -597,7 +602,7 @@ def make_type(op_var, dtype, shape=[1], count=0):
@staticmethod
# TODO mul + add scaled_y should handle in conversion
def Add(op_var, shape, dtype, x, y, **kwargs_list):
    """Emit builder IR for x + y; a scalar y is first materialized via make_const."""
    src_code, y = EnflameOverrides.make_const(op_var, y, dtype)
    src_code += f"builder::Op {op_var} = builder::Add({x}, {y});"
    return src_code

Expand All @@ -609,19 +614,19 @@ def Convert(op_var, shape, dtype, x, y, **kwargs_list):

@staticmethod
def Div(op_var, shape, dtype, x, y, **kwargs_list):
    """Emit builder IR for elementwise x / y; scalar y becomes a const of *dtype*."""
    const_code, rhs = EnflameOverrides.make_const(op_var, y, dtype)
    return const_code + "builder::Op {} = builder::Div({}, {});".format(op_var, x, rhs)

@staticmethod
def Sub(op_var, shape, dtype, x, y, **kwargs_list):
    """Emit builder IR for elementwise x - y; scalar y becomes a const of *dtype*."""
    const_code, rhs = EnflameOverrides.make_const(op_var, y, dtype)
    return const_code + "builder::Op {} = builder::Sub({}, {});".format(op_var, x, rhs)

@staticmethod
def Mul(op_var, shape, dtype, x, y, **kwargs_list):
    """Emit builder IR for elementwise x * y; scalar y becomes a const of *dtype*."""
    const_code, rhs = EnflameOverrides.make_const(op_var, y, dtype)
    return const_code + "builder::Op {} = builder::Mul({}, {});".format(op_var, x, rhs)

Expand Down Expand Up @@ -653,19 +658,19 @@ def Less(op_var, shape, dtype, x, y, **kwargs_list):

@staticmethod
def Equal(op_var, shape, dtype, x, y, **kwargs_list):
    """Emit builder IR for elementwise x == y; scalar y becomes a float const."""
    const_code, rhs = EnflameOverrides.make_const(op_var, y)
    return const_code + "builder::Op {} = builder::Equal({}, {});".format(op_var, x, rhs)

@staticmethod
def LessEqual(op_var, shape, dtype, x, y, **kwargs_list):
    """Emit builder IR for elementwise x <= y; scalar y becomes a float const."""
    const_code, rhs = EnflameOverrides.make_const(op_var, y)
    return const_code + "builder::Op {} = builder::LessEqual({}, {});".format(op_var, x, rhs)

@staticmethod
def NotEqual(op_var, shape, dtype, data_type, x, y, **kwargs_list):
    """Emit builder IR for elementwise x != y; scalar y becomes a const of *data_type*."""
    src_code, y = EnflameOverrides.make_const(op_var, y, data_type)
    src_code += f"builder::Op {op_var} = builder::NotEqual({x}, {y});"
    return src_code
Expand All @@ -680,7 +685,7 @@ def Neg(op_var, shape, dtype, x, **kwargs_list):

@staticmethod
def Pow(op_var, shape, dtype, x, y, **kwargs_list):
    """Emit builder IR for elementwise x ** y; scalar y becomes a float const."""
    const_code, exponent = EnflameOverrides.make_const(op_var, y)
    return const_code + "builder::Op {} = builder::Pow({}, {});".format(op_var, x, exponent)

Expand All @@ -692,10 +697,28 @@ def Exp(op_var, shape, dtype, x, **kwargs_list):
def Sqrt(op_var, shape, dtype, x, **kwargs_list):
    """Emit builder IR for elementwise sqrt(x)."""
    return "builder::Op {} = builder::Sqrt({});".format(op_var, x)

@staticmethod
def Sin(op_var, shape, dtype, x, **kwargs_list):
return f"builder::Op {op_var} = builder::Sin({x});"

@staticmethod
def Cos(op_var, shape, dtype, x, **kwargs_list):
return f"builder::Op {op_var} = builder::Cos({x});"

@staticmethod
def Relu(op_var, shape, dtype, x, **kwargs_list):
return f"builder::Op {op_var} = builder::Relu({x});"

@staticmethod
def Erf(op_var, shape, dtype, x, **kwargs_list):
return f"builder::Op {op_var} = builder::Erf({x});"

@staticmethod
def Split(op_var, shape, dtype, x, split, axis=0, num_outputs=0, **kwargs_list):
    """Emit builder IR splitting x along *axis*; split spec becomes an int64 const."""
    const_code, split_op = EnflameOverrides.make_const(op_var, split, torch.int64)
    return const_code + f"builder::Op {op_var} = builder::GeneralSplit({x}, {split_op}, {axis});"

@staticmethod
def Sigmoid(op_var, shape, dtype, x, **kwargs_list):
return f"builder::Op {op_var} = builder::Sigmoid({x});"
Expand Down Expand Up @@ -791,14 +814,14 @@ def Expand(op_var, shape, dtype, x, new_shape, **kwargs_list):

@staticmethod
def Squeeze(op_var, shape, dtype, x, y, **kwargs_list):
    """Emit builder IR removing dimension(s) y from x; y becomes an int64 const."""
    src_code, y = EnflameOverrides.make_const(op_var, y, torch.int64)
    src_code += f"builder::Op {op_var} = builder::Squeeze({x}, {y});"
    return src_code

@staticmethod
def Unsqueeze(op_var, shape, dtype, x, y, **kwargs_list):
    """Emit builder IR inserting dimension(s) y into x; y becomes an int64 const."""
    src_code, y = EnflameOverrides.make_const(op_var, y, torch.int64)
    src_code += f"builder::Op {op_var} = builder::Unsqueeze({x}, {y});"
    return src_code
Expand Down Expand Up @@ -836,9 +859,9 @@ def SliceInDim(op_var, shape, dtype, x, dim, start, end, step, **kwargs_list):

@staticmethod
def SliceScatter(op_var, shape, dtype, x, y, dim, start, end, step, **kwargs_list):
src_code_index, op_start_index = EnflameOverrides.make_const_if_scalar(
src_code_index, op_start_index = EnflameOverrides.make_const(
op_var, 0, torch.int64, 0)
src_code_index_dim, op_start_index_dim = EnflameOverrides.make_const_if_scalar(
src_code_index_dim, op_start_index_dim = EnflameOverrides.make_const(
op_var, start, torch.int64, 1)
src_code = src_code_index + src_code_index_dim
src_code += f"builder::Op {op_var} = builder::DynamicUpdateSlice({x}, {y}, {{{', '.join([op_start_index_dim if i == dim else op_start_index for i in range(len(shape))])}}});"
Expand Down
23 changes: 23 additions & 0 deletions dicp/dicp/vendor/TopsGraph/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,33 @@ def Rsqrt(self, *args, **kwargs):
def Exp(self, *args, **kwargs):
return self.get_proxy(tops_op.Exp, args, kwargs)

@register_conversion(aten.sin)
def Sin(self, *args, **kwargs):
    """Lower ``aten.sin`` to the Tops ``Sin`` operator via a graph proxy."""
    return self.get_proxy(tops_op.Sin, args, kwargs)

@register_conversion(aten.cos)
def Cos(self, *args, **kwargs):
    """Lower ``aten.cos`` to the Tops ``Cos`` operator via a graph proxy."""
    return self.get_proxy(tops_op.Cos, args, kwargs)

@register_conversion(aten.relu)
def Relu(self, *args, **kwargs):
    """Lower ``aten.relu`` to the Tops ``Relu`` operator via a graph proxy."""
    return self.get_proxy(tops_op.Relu, args, kwargs)

@register_conversion(aten.erf)
def Erf(self, *args, **kwargs):
    """Lower ``aten.erf`` to the Tops ``Erf`` operator via a graph proxy."""
    return self.get_proxy(tops_op.Erf, args, kwargs)

@register_conversion(aten.split.Tensor)
def Split(self, a, size, dim=0, **kwargs):
    """Lower ``aten.split.Tensor`` to SliceInDim ops gathered by MakeTuple.

    Slices *a* along *dim* into chunks of *size* elements; the final chunk
    may be smaller, mirroring ``torch.split`` semantics.
    """
    in_shape = a.node.meta["val"].shape
    dim = dim % len(in_shape)  # normalize a negative dim index
    dim_len = in_shape[dim]
    sections = (dim_len + size - 1) // size  # ceil division
    splits = (
        self.get_proxy(
            tops_op.SliceInDim,
            (a, dim, i * size, min((i + 1) * size, dim_len), 1),
        )
        for i in range(sections)
    )
    return self.get_proxy(tops_op.MakeTuple, tuple(splits))

@register_conversion(aten.sum)
def ReduceSum(self, a, *args, **kwargs):
in_dtype = a.node.meta["val"].dtype
Expand Down
41 changes: 35 additions & 6 deletions dicp/dicp/vendor/TopsGraph/tops_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,46 @@ def __init__(self, a):
self.torch_op = aten.exp


class Sin(Operator):
    """Elementwise sine; falls back to ``aten.sin`` for reference execution."""

    def __init__(self, a):
        super().__init__("Sin")
        self.torch_op = aten.sin
        self.a = a


class Cos(Operator):
    """Elementwise cosine; falls back to ``aten.cos`` for reference execution."""

    def __init__(self, a):
        super().__init__("Cos")
        self.torch_op = aten.cos
        self.a = a


class Relu(Operator):
    """Elementwise ReLU; falls back to ``aten.relu`` for reference execution."""

    def __init__(self, a):
        super().__init__("Relu")
        self.torch_op = aten.relu
        self.a = a


class Erf(Operator):
    """Elementwise error function; falls back to ``aten.erf`` for reference execution."""

    def __init__(self, a):
        super().__init__("Erf")
        self.torch_op = aten.erf
        self.a = a


class Split(Operator):
    """Tensor split; falls back to ``aten.split`` for reference execution."""

    def __init__(self, *args, **kwargs):
        super().__init__("Split")
        self.args = args
        self.kwargs = kwargs
        self.torch_op = aten.split

    def __call__(self, *args, **kwargs):
        # Forward only the first three positionals (input, size, dim);
        # anything beyond that is dropped before the base-class call.
        return super().__call__(*args[:3], **kwargs)


class ReduceSum(Operator):
def __init__(self, *args, **kwargs):
super().__init__("ReduceSum")
Expand Down Expand Up @@ -741,12 +774,8 @@ def __init__(self, a, b):
super().__init__("MakeTuple")
self.torch_op = torch.empty_like

def __call__(self, *args):
    """Collapse proxy arguments to their fake ``meta['val']`` values.

    Returns a real tuple — not a generator — because the result stands in
    for MakeTuple's faked output and is indexed/re-iterated downstream;
    a single-use generator here would silently break shape propagation.
    """
    return tuple(arg.meta["val"] if hasattr(arg, "meta") else arg for arg in args)


class XlaGather(Operator):
Expand Down
40 changes: 40 additions & 0 deletions dicp/test/op/test_cos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import pytest
from common.utils import (
torch,
dynamo,
parse_args,
compile_model,
get_device,
Size,
update_dynamo_config,
)


class OpModule(torch.nn.Module):
    """Minimal module that applies elementwise cosine through the aten op."""

    def forward(self, a):
        return torch.ops.aten.cos.default(a)


# Module-level harness state shared by the parametrized tests below.
model = OpModule()
args = parse_args()  # backend / dynamic flags come from the runner CLI
compiled_model = compile_model(model, args.backend, args.dynamic)


class TestCos():
    """Compare eager aten.cos against the dicp-compiled model."""

    @pytest.mark.parametrize("dtype", [torch.float32])
    @pytest.mark.parametrize("sizes", [Size((5,), (5, 3)), Size((3, 5), (5, 3)), Size((2, 3, 4), (2, 4))])
    @pytest.mark.parametrize("compiled_model", compiled_model)
    def test_torch_cos(self, sizes, dtype, compiled_model):
        device = get_device()
        # Pick the static or dynamic shape variant to match the compile mode.
        size = sizes.dynamic if compiled_model.dynamic else sizes.static
        input1 = torch.randn(size, dtype=dtype)

        dicp_input1 = input1.to(device)

        output = model(input1)
        dynamo.reset()
        update_dynamo_config(compiled_model.dynamic)
        dicp_output = compiled_model.model(dicp_input1)

        # Loose atol: device kernels may differ slightly from CPU eager.
        assert torch.allclose(output, dicp_output.cpu(), atol=1e-04, equal_nan=True)
40 changes: 40 additions & 0 deletions dicp/test/op/test_erf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import pytest
from common.utils import (
torch,
dynamo,
parse_args,
compile_model,
get_device,
Size,
update_dynamo_config,
)


class OpModule(torch.nn.Module):
    """Minimal module that applies the elementwise error function through the aten op."""

    def forward(self, a):
        return torch.ops.aten.erf.default(a)


# Module-level harness state shared by the parametrized tests below.
model = OpModule()
args = parse_args()  # backend / dynamic flags come from the runner CLI
compiled_model = compile_model(model, args.backend, args.dynamic)


class TestErf():
    """Compare eager aten.erf against the dicp-compiled model."""

    @pytest.mark.parametrize("dtype", [torch.float32])
    @pytest.mark.parametrize("sizes", [Size((5,), (5, 3)), Size((3, 5), (5, 3)), Size((2, 3, 4), (2, 4))])
    @pytest.mark.parametrize("compiled_model", compiled_model)
    def test_torch_erf(self, sizes, dtype, compiled_model):
        device = get_device()
        # Pick the static or dynamic shape variant to match the compile mode.
        size = sizes.dynamic if compiled_model.dynamic else sizes.static
        input1 = torch.randn(size, dtype=dtype)

        dicp_input1 = input1.to(device)

        output = model(input1)
        dynamo.reset()
        update_dynamo_config(compiled_model.dynamic)
        dicp_output = compiled_model.model(dicp_input1)

        # Loose atol: device kernels may differ slightly from CPU eager.
        assert torch.allclose(output, dicp_output.cpu(), atol=1e-04, equal_nan=True)
40 changes: 40 additions & 0 deletions dicp/test/op/test_sin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import pytest
from common.utils import (
torch,
dynamo,
parse_args,
compile_model,
get_device,
Size,
update_dynamo_config,
)


class OpModule(torch.nn.Module):
    """Minimal module that applies elementwise sine through the aten op."""

    def forward(self, a):
        return torch.ops.aten.sin.default(a)


# Module-level harness state shared by the parametrized tests below.
model = OpModule()
args = parse_args()  # backend / dynamic flags come from the runner CLI
compiled_model = compile_model(model, args.backend, args.dynamic)


class TestSin():
    """Compare eager aten.sin against the dicp-compiled model."""

    @pytest.mark.parametrize("dtype", [torch.float32])
    @pytest.mark.parametrize("sizes", [Size((5,), (5, 3)), Size((3, 5), (5, 3)), Size((2, 3, 4), (2, 4))])
    @pytest.mark.parametrize("compiled_model", compiled_model)
    def test_torch_sin(self, sizes, dtype, compiled_model):
        device = get_device()
        # Pick the static or dynamic shape variant to match the compile mode.
        size = sizes.dynamic if compiled_model.dynamic else sizes.static
        input1 = torch.randn(size, dtype=dtype)

        dicp_input1 = input1.to(device)

        output = model(input1)
        dynamo.reset()
        update_dynamo_config(compiled_model.dynamic)
        dicp_output = compiled_model.model(dicp_input1)

        # NOTE(review): no atol here unlike the sibling cos/erf tests — presumably
        # default tolerances suffice for sin; confirm intentional.
        assert torch.allclose(output, dicp_output.cpu(), equal_nan=True)
Loading

0 comments on commit 47d0f8a

Please sign in to comment.