[dicp][tops] Run llama_finetune success. #440

Merged · 8 commits · Dec 4, 2023
26 changes: 19 additions & 7 deletions dicp/dicp/vendor/TopsGraph/codegen/enflame.py
@@ -154,6 +154,7 @@ def call_function(self, name, target, args, kwargs):
return

def output(self, name, target, args, kwargs):
self.inplace_dict = kwargs
for i in range(0, len(args[0])):
self.output_args.append(args[0][i])

@@ -413,14 +414,16 @@ def gen_call_func(self):
for i in range(len(self.output_args)):
if not isinstance(self.output_args[i], type(None)):
bufs.append(self.output_args[i].name)
if self.output_args[i] not in self.input_args:
if self.output_args[i] not in self.input_args and bufs[-1] not in self.inplace_dict.keys():
otensor = self.output_args[i].meta['val']
call_body.writeline(bufs[-1] + " = " + self.gen_empty_tensor(otensor))
else:
bufs.append("buf" + str(i))
none_bufs.append(bufs[-1])
call_body.writeline(
bufs[-1] + " = " + ("empty_strided((), ())"))
for i in range(len(bufs) - len(self.inplace_dict), len(bufs)):
bufs[i] = self.inplace_dict[bufs[i]]

call_body.writeline("")

@@ -473,7 +476,7 @@ def gen_call_func(self):
if dipu_flag:
call_body.writeline(f"torch_dipu.current_stream({self.device_id}).synchronize()")

call_body.writeline(f"return ({', '.join(bufs)})")
call_body.writeline(f"return ({', '.join(bufs[:len(bufs)-len(self.inplace_dict)])})")

call_func = IndentedBuffer()
call_func.writeline("def call(args):")
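Note on the two hunks above: self.inplace_dict is the kwargs dict that HandleInplaceCopyPass (opset_transform.py, below) attaches to the graph's output node, mapping each inplace output value to the placeholder it writes back into. gen_call_func therefore skips allocating fresh buffers for those outputs, rewrites the trailing buffer names to the aliased inputs, and returns only the original outputs. A standalone sketch of that remapping, with hypothetical buffer names:

def remap_inplace_bufs(bufs, inplace_dict):
    # Mirrors the loop added in gen_call_func: the last len(inplace_dict)
    # buffers were appended for Copy_ destinations and are replaced by the
    # input argument they alias.
    bufs = list(bufs)
    for i in range(len(bufs) - len(inplace_dict), len(bufs)):
        bufs[i] = inplace_dict[bufs[i]]
    # Only the leading, non-inplace buffers are returned from call().
    return bufs, bufs[:len(bufs) - len(inplace_dict)]

# Hypothetical graph: one regular output buf0, plus add_1 copied back into
# the placeholder arg0_1 by a Copy_ node.
bufs, returned = remap_inplace_bufs(["buf0", "add_1"], {"add_1": "arg0_1"})
assert bufs == ["buf0", "arg0_1"] and returned == ["buf0"]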
@@ -568,6 +571,10 @@ def Clone(op_var, shape, dtype, x, **kwargs_list):
def Copy(op_var, shape, dtype, x, y, **kwargs_list):
return f"builder::Op {op_var} = {y};"

@staticmethod
def Copy_(op_var, shape, dtype, x, y, **kwargs_list):
return f"builder::Op {op_var} = {y};"

@staticmethod
def LiftFreshCopy(op_var, shape, dtype, x, **kwargs_list):
return f"builder::Op {op_var} = {x};"
@@ -582,8 +589,10 @@ def make_const_if_scalar(op_var, value, dtype=torch.float32, count=0):
if isinstance(value, numbers.Number):
src_code = f"{cxx_type_set[dtype]} {op_var}_const_value{count} = static_cast<{cxx_type_set[dtype]}>({value});\n"
value = f"{op_var}_const{count}"
src_code += f"builder::Type {op_var}_const_type{count}({{1}}, {type_set[dtype]});\n"
src_code += f"builder::Op {value} = builder::Const(hlir_builder, static_cast<void *>(&{op_var}_const_value{count}), {op_var}_const_type{count});\n"
const_type = dtype if dtype != torch.float16 else torch.float32
src_code += f"builder::Op {value} = builder::Const(hlir_builder, static_cast<void *>(&{op_var}_const_value{count}), builder::Type({{1}}, {type_set[const_type]}));\n"
if dtype == torch.float16:
src_code += f"{value} = builder::Convert({value}, builder::Type({{1}}, {type_set[dtype]}));\n"
return src_code, value

@staticmethod
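The float16 branch added to make_const_if_scalar builds the constant with a float32 builder type and then converts it to f16, rather than handing builder::Const an f16-typed host buffer directly. Assuming type_set maps torch.float32/torch.float16 to builder::PrimitiveType::F32()/F16() and cxx_type_set[torch.float16] is a float-compatible host type (both are assumptions, not shown in this diff), the emitted text for a hypothetical op_var op7 with value 1.0 would look roughly like:

# Sketch only; the concrete type strings come from cxx_type_set / type_set.
emitted_f16_const = (
    "float op7_const_value0 = static_cast<float>(1.0);\n"
    "builder::Op op7_const0 = builder::Const(hlir_builder, "
    "static_cast<void *>(&op7_const_value0), "
    "builder::Type({1}, builder::PrimitiveType::F32()));\n"
    "op7_const0 = builder::Convert(op7_const0, "
    "builder::Type({1}, builder::PrimitiveType::F16()));\n"
)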
@@ -614,7 +623,9 @@ def Div(op_var, shape, dtype, x, y, **kwargs_list):

@staticmethod
def Sub(op_var, shape, dtype, x, y, **kwargs_list):
src_code, y = EnflameOverrides.make_const_if_scalar(op_var, y, dtype)
src_code_x, x = EnflameOverrides.make_const_if_scalar(op_var, x, dtype)
src_code_y, y = EnflameOverrides.make_const_if_scalar(op_var, y, dtype)
src_code = src_code_x + src_code_y
src_code += f"builder::Op {op_var} = builder::Sub({x}, {y});"
return src_code

@@ -679,7 +690,7 @@ def Neg(op_var, shape, dtype, x, **kwargs_list):

@staticmethod
def Pow(op_var, shape, dtype, x, y, **kwargs_list):
src_code, y = EnflameOverrides.make_const_if_scalar(op_var, y)
src_code, y = EnflameOverrides.make_const_if_scalar(op_var, y, dtype)
src_code += f"builder::Op {op_var} = builder::Pow({x}, {y});"
return src_code

@@ -761,7 +772,8 @@ def Full(op_var, out_shape, out_dtype, size, value, **kwargs_list):

@staticmethod
def FullLike(op_var, shape, dtype, x, value, **kwargs_list):
return f"builder::Op {op_var} = builder::FullLike({x}, {value});"
src_code = f"builder::Op {op_var} = builder::FullLike({x}, {value}, {type_set[dtype]}, {{{str(shape).split('[')[-1].split(']')[0]}}});"
return src_code

@staticmethod
def Transpose(op_var, shape, dtype, x, permution=[0, 1], **kwargs_list):
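FullLike now threads the output dtype and shape through to the builder call; the split('[')/split(']') expression just turns the Python shape repr (a list or torch.Size) into a C++ brace initializer. A quick check of that string manipulation with illustrative shapes:

import torch

def cxx_shape(shape):
    # Same expression as used in FullLike above.
    return "{" + str(shape).split('[')[-1].split(']')[0] + "}"

assert cxx_shape([2, 3, 4]) == "{2, 3, 4}"
assert cxx_shape(torch.Size([5, 3])) == "{5, 3}"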
27 changes: 15 additions & 12 deletions dicp/dicp/vendor/TopsGraph/conversion.py
@@ -211,10 +211,14 @@ def Clone(self, *args, **kwargs):
return self.get_proxy(tops_op.Clone, args, kwargs)

# Copy_ is only validated for inplace copy of input parameters in optimizer, be careful about other cases.
@register_conversion([aten.copy.default, aten.copy_.default])
@register_conversion(aten.copy.default)
def Copy(self, *args, **kwargs):
return self.get_proxy(tops_op.Copy, args, kwargs)

@register_conversion(aten.copy_.default)
def Copy_(self, *args, **kwargs):
return self.get_proxy(tops_op.Copy_, args, kwargs)

@register_conversion(aten.lift_fresh_copy.default)
def LiftFreshCopy(self, *args, **kwargs):
return self.get_proxy(tops_op.LiftFreshCopy, args, kwargs)
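The comment above scopes Copy_: the only pattern exercised so far is an optimizer writing updated values back into an input parameter. A minimal example of the pattern that now routes through aten.copy_.default and tops_op.Copy_ (the SGD arithmetic is just for illustration):

import torch

def sgd_step(param: torch.Tensor, grad: torch.Tensor, lr: float = 0.1):
    # The in-place write below is what lowers to aten.copy_.default.
    param.copy_(param - lr * grad)
    return param

p = torch.ones(4)
sgd_step(p, torch.full((4,), 2.0))
assert torch.allclose(p, torch.full((4,), 0.8))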
@@ -403,11 +407,11 @@ def Slice(self, a, dim=0, start=0, end=-1, step=1, **kwargs):
@register_conversion(aten.slice_scatter.default)
def SliceScatter(self, a, b, dim=0, start=0, end=-1, step=1):
operand_shape = a.node.meta["val"].shape
dim = dim % len(operand_shape)
start = 0 if start is None else start
end = operand_shape[dim] if end is None else end
end = end % operand_shape[dim] if end < operand_shape[dim] else operand_shape[dim]
assert end == operand_shape[dim] and step == 1, "limited support"
if end != operand_shape[dim]:
Warning(f"SliceScatter encounter unsupported end value: {end}, this will affect precision!")
if step != 1:
Warning(f"SliceScatter encounter unsupported step value: {step}, this will affect precision!")
return self.get_proxy(tops_op.SliceScatter, (a, b, dim, start, end, step))

@register_conversion(aten.select.int)
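SliceScatter previously asserted on anything other than a full-dimension, step-1 scatter; it now proceeds and only flags the unsupported cases (note that Warning(...) builds the warning object without emitting it; warnings.warn would actually surface the message). The exact versus flagged cases, in plain PyTorch:

import torch

a = torch.zeros(4, 6)
b = torch.ones(4, 6)
# Full-dimension, step-1 scatter: the case the lowering handles exactly.
torch.ops.aten.slice_scatter.default(a, b, 1, 0, 6, 1)
# Partial scatter (end != dim size): previously an assertion failure, now
# lowered with a precision warning.
torch.ops.aten.slice_scatter.default(a, b[:, :3], 1, 0, 3, 1)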
@@ -435,18 +439,17 @@ def OnesLike(self, *args, **kwargs):

@register_conversion(aten.scalar_tensor.default)
def Scalar(self, a, **kwargs):
if "dtype" in kwargs:
real_dtype = kwargs["dtype"]
if real_dtype not in (torch.int64, torch.float32):
kwargs["dtype"] = torch.float32
scalar = self.get_proxy(tops_op.Scalar, (a,), kwargs)
return self.get_proxy(tops_op.Convert(), (scalar, real_dtype))
out_dtype = fx_traceback.get_current_meta()['val'].dtype
if out_dtype is torch.float16:
kwargs["dtype"] = torch.float32
scalar = self.get_proxy(tops_op.Scalar, (a,), kwargs)
return self.get_proxy(tops_op.Convert(), (scalar, out_dtype))
return self.get_proxy(tops_op.Scalar, (a,), kwargs)

@register_conversion(aten.embedding)
def Embedding(self, *args, **kwargs):
idx_rank = len(args[1].node.meta['val'].shape)
return self.get_proxy(tops_op.XlaGather, (*args, [idx_rank,], [0,], [0,], idx_rank,
return self.get_proxy(tops_op.XlaGather, (args[0], args[1], [idx_rank,], [0,], [0,], idx_rank,
[1, args[0].node.meta['val'].shape[1]]))

@register_conversion(prims.convert_element_type)
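The Embedding conversion now passes the weight and index proxies explicitly instead of splatting *args into XlaGather, and the gather it emits has to reproduce a plain row lookup. Reference semantics in stock PyTorch (shapes illustrative):

import torch

weight = torch.randn(10, 4)           # [num_embeddings, embedding_dim]
idx = torch.tensor([[1, 3], [5, 7]])  # indices of arbitrary rank; rank 2 here
out = torch.ops.aten.embedding.default(weight, idx)
# An embedding lookup gathers whole rows along dim 0 of the weight.
assert torch.equal(out, weight[idx])
assert out.shape == (*idx.shape, weight.shape[1])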
25 changes: 25 additions & 0 deletions dicp/dicp/vendor/TopsGraph/opset_transform.py
@@ -9,6 +9,28 @@
from dicp.vendor.TopsGraph.conversion import tops_patterns, aten_patterns_cls_list, tops_patterns_cls_list


class HandleInplaceCopyPass():
def transform(self, gm: torch.fx.GraphModule):
nodes = list(gm.graph.nodes)
last_node = nodes[-1]
assert last_node.op == "output"

origin_outputs = list(last_node.args[0])
inplace_outputs = []
inplace_dict = {}
for node in reversed(nodes):
if node.op not in ["placeholder", "output"] and not isinstance(node.target, str):
if node.target.name() == "Copy_":
if node.args[0].op == "placeholder" and node.args[0].name not in inplace_dict.values():
inplace_outputs.append(node.args[1])
inplace_dict[node.args[1].name] = node.args[0].name

assert len(last_node.kwargs) == 0
last_node._Node__update_args_kwargs((origin_outputs + inplace_outputs, ), inplace_dict)

return gm


def topsgraph_opset_transform(
gm: torch.fx.GraphModule,
):
@@ -26,4 +48,7 @@ def topsgraph_opset_transform(
gm = BackendPatternMatcherTransformer(
tops_patterns, tops_patterns_cls_list).transform(gm)

# handle inplace copy operation: get inplace copy args to update outputs.
gm = HandleInplaceCopyPass().transform(gm)

return gm
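At the graph level, the new pass walks the nodes in reverse, and for every Copy_ whose destination is a placeholder it appends the copied value to the output node and records value -> placeholder in a dict that becomes the output node's kwargs (consumed by output()/gen_call_func in enflame.py above). A data-level sketch with hypothetical node names:

# Hypothetical graph: placeholders arg0_1, arg1_1; add_1 = add(arg0_1, arg1_1);
# copy__0 = Copy_(arg0_1, add_1); output((add_1,))
origin_outputs = ["add_1"]
copy_nodes = [("copy__0", "arg0_1", "add_1")]  # (node, placeholder dst, copied src)

inplace_outputs, inplace_dict = [], {}
for _node, dst, src in reversed(copy_nodes):
    if dst not in inplace_dict.values():  # the last write to a placeholder wins
        inplace_outputs.append(src)
        inplace_dict[src] = dst

# The output node is rewritten to output((add_1, add_1), {"add_1": "arg0_1"}).
assert origin_outputs + inplace_outputs == ["add_1", "add_1"]
assert inplace_dict == {"add_1": "arg0_1"}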
8 changes: 8 additions & 0 deletions dicp/dicp/vendor/TopsGraph/tops_op.py
@@ -244,6 +244,14 @@ def __init__(self, *args, **kwargs):
self.torch_op = torch.ops.aten.copy


class Copy_(Operator):
def __init__(self, *args, **kwargs):
super().__init__("Copy_")
self.args = args
self.kwargs = kwargs
self.torch_op = torch.ops.aten.copy_


class LiftFreshCopy(Operator):
def __init__(self, *args, **kwargs):
super().__init__("LiftFreshCopy")
50 changes: 50 additions & 0 deletions dicp/test/op/test_inplace_copy.py
@@ -0,0 +1,50 @@
import pytest
from common.utils import (
torch,
dynamo,
parse_args,
compile_model,
get_device,
Size,
update_dynamo_config,
)


class OpModule(torch.nn.Module):
def forward(self, a, b):
res_Tensor = torch.ops.aten.add.Tensor(a, b)
a.copy_(res_Tensor)
return res_Tensor


model = OpModule()
args = parse_args()
compiled_model = compile_model(model, args.backend, args.dynamic)


class TestInplaceCopy():
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("sizes", [Size((5,), (5, 3)), Size((3, 5), (5, 3)), Size((2, 3, 4), (2, 4))])
@pytest.mark.parametrize("compiled_model", compiled_model)
def test_torch_add(self, sizes, dtype, compiled_model):
device = get_device()
size = sizes.dynamic if compiled_model.dynamic else sizes.static
input1 = torch.ones(size, dtype=dtype)
input2 = torch.ones(size, dtype=dtype)

dicp_input1 = input1.to(device)
dicp_input2 = input2.to(device)

output = model(input1, input2)
dynamo.reset()
update_dynamo_config(compiled_model.dynamic)
dicp_output = compiled_model.model(dicp_input1, dicp_input2)

for i, item in enumerate(output):
if isinstance(item, torch.Tensor):
assert torch.allclose(item, dicp_output[i].cpu(), equal_nan=True)
else:
assert item == dicp_output[i]

# Confirm the correctness of the inplace copy result.
assert torch.allclose(dicp_input1.cpu(), dicp_output.cpu(), equal_nan=True)