Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 5th No.47】API转换 103-124 #346

Merged
merged 9 commits into from
Dec 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
332 changes: 316 additions & 16 deletions paconvert/api_mapping.json

Large diffs are not rendered by default.

106 changes: 33 additions & 73 deletions paconvert/api_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,20 @@ def get_paddle_class_nodes(self, func, args, kwargs):
return "delete"


class AtleastMatcher(BaseMatcher):
def get_paddle_nodes(self, args, kwargs):
new_args = self.parse_args(args)
if new_args[0][0] == "(" and new_args[0][-1] == ")":
new_args[0] = new_args[0][1:-1]
if new_args[0][0] == "[" and new_args[0][-1] == "]":
new_args[0] = new_args[0][1:-1]
new_kwargs = self.parse_kwargs(kwargs)
code = "{}({})".format(
self.get_paddle_api(), self.args_and_kwargs_to_str(new_args, new_kwargs)
)
return ast.parse(code).body


class UnchangeMatcher(BaseMatcher):
def get_paddle_class_attribute_nodes(self, node):
return "unchange"
Expand Down Expand Up @@ -1549,7 +1563,11 @@ def generate_code(self, kwargs):
if len(kwargs) == 0:
code = f"str({self.paddleClass}.dtype)"
else:
code = f"{self.paddleClass}.astype({kwargs['dtype']})"
# For torch.nn.Module.type, torch.nn.Module.type use torch.Tensor.type
if "dst_type" in kwargs:
code = f"{self.paddleClass}.astype({kwargs['dst_type']})"
else:
code = f"{self.paddleClass}.astype({kwargs['dtype']})"
return code


Expand Down Expand Up @@ -2111,61 +2129,21 @@ def get_paddle_nodes(self, args, kwargs):


class TensorToMatcher(BaseMatcher):
def generate_aux_code(self):
CODE_TEMPLATE = textwrap.dedent(
"""
def to(self, *args, **kwargs):
args_list = ["x", "y", "non_blocking", "copy", "memory_format"]
new_kwargs = {}
for i, node in enumerate(args):
k = args_list[i]
new_kwargs[k] = node
for node in kwargs:
v = kwargs[node]
new_kwargs[node] = v
kwargs = new_kwargs
if not kwargs:
return self
elif "tensor" in kwargs:
return paddle.cast(self, "{}.dtype".format(kwargs["tensor"]))
elif "dtype" in kwargs:
return paddle.cast(self, "{}".format(kwargs["dtype"]))
elif "device" in kwargs and "dtype" not in kwargs:
return self
elif kwargs:
if "y" not in kwargs and "x" in kwargs:
if isinstance(kwargs["x"], paddle.dtype):
dtype = kwargs["x"]
elif isinstance(kwargs["x"], str) and kwargs["x"] not in ['cpu', 'cuda', 'ipu', 'xpu']:
dtype = kwargs["x"]
elif isinstance(kwargs["x"], paddle.Tensor):
dtype = kwargs["x"].dtype
else:
dtype = self.dtype
return paddle.cast(self, dtype)

elif "y" in kwargs and "x" in kwargs:
if isinstance(kwargs["x"], paddle.dtype):
dtype = kwargs["x"]
elif isinstance(kwargs["x"], str):
if x not in ['cpu', 'cuda', 'ipu', 'xpu']:
dtype = kwargs["x"]
else:
dtype = kwargs["y"] if isinstance(kwargs["y"], str) else self.dtype
else:
dtype = kwargs["x"]
return paddle.cast(self, dtype)
else:
return self

setattr(paddle.Tensor, 'to', to)
"""
def get_paddle_nodes(self, args, kwargs):
new_args = self.parse_args(args)
new_kwargs = self.parse_kwargs(kwargs)
if new_kwargs is None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

什么时候会return None,return None时一般是不支持,应该继续return None

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

API有bug时返回None,已修改return None

return None
if "copy" in new_kwargs:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个是不是考虑json配置下,直接kwargs_change成空字符串

Copy link
Contributor Author

@co63oc co63oc Dec 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kwargs_change 使用GenericMatcher,GenericMatcher 中 dtype, device转换不同不适合使用,然后只有参数 copy ,不需要使用 kwargs_change

new_kwargs.pop("copy")
if "memory_format" in new_kwargs:
new_kwargs.pop("memory_format")
if "non_blocking" in new_kwargs:
new_kwargs["blocking"] = "not " + new_kwargs.pop("non_blocking").strip("()")
Copy link
Collaborator

@zhwesky2010 zhwesky2010 Dec 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

有个 Reverse* 的Matcher可以参考,看是不是统一成一个功能更强的Matcher更好

Copy link
Contributor Author

@co63oc co63oc Dec 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是 ReverseMomentumMatcher ,输入是数值,类型不同
image

code = "{}.to({})".format(
self.paddleClass, self.args_and_kwargs_to_str(new_args, new_kwargs)
)
return CODE_TEMPLATE

def get_paddle_class_nodes(self, func, args, kwargs):
self.write_aux_code()
return "unchange"
return ast.parse(code).body


class TensorRequiresGradMatcher(BaseMatcher):
Expand Down Expand Up @@ -2964,24 +2942,6 @@ def get_paddle_nodes(self, args, kwargs):
return ast.parse(code).body


class HypotMatcher(BaseMatcher):
def generate_code(self, kwargs):
if "input" not in kwargs:
kwargs["input"] = self.paddleClass

API_TEMPLATE = textwrap.dedent(
"""
paddle.pow({}**2 + {}**2, 1/2)
"""
)
code = API_TEMPLATE.format(kwargs["input"], kwargs["other"])

if "out" in kwargs and kwargs["out"] != "None":
code = "paddle.assign({}, output={})".format(code, kwargs["out"])

return code


class TensorViewMatcher(BaseMatcher):
def generate_aux_code(self):
CODE_TEMPLATE = textwrap.dedent(
Expand Down
57 changes: 33 additions & 24 deletions tests/test_Tensor_diagonal_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,48 +23,57 @@ def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
input = torch.zeros(3, 3)
src = torch.ones(3)
input = torch.arange(6.0).reshape((2, 3))
src = torch.ones((2,))
result = input.diagonal_scatter(src)
"""
)
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle does not support this function temporarily",
)
obj.run(pytorch_code, ["result"])


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
input = torch.zeros(3, 3)
src = torch.ones(3)
input = torch.arange(6.0).reshape((2, 3))
src = torch.ones((2,))
result = input.diagonal_scatter(src=src)
"""
)
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle does not support this function temporarily",
)
obj.run(pytorch_code, ["result"])


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
input = torch.zeros(3, 3)
src = torch.ones(3)
result = input.diagonal_scatter(src=src, offset=0, dim1=-2)
input = torch.arange(6.0).reshape((2, 3))
src = torch.ones((2,))
result = input.diagonal_scatter(offset=0, src=src, dim2=1, dim1=-2)
"""
)
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle does not support this function temporarily",
obj.run(pytorch_code, ["result"])


def test_case_4():
pytorch_code = textwrap.dedent(
"""
import torch
input = torch.arange(6.0).reshape((2, 3))
src = torch.ones((2,))
result = input.diagonal_scatter(src=src, offset=0, dim1=-2, dim2=1)
co63oc marked this conversation as resolved.
Show resolved Hide resolved
"""
)
obj.run(pytorch_code, ["result"])


def test_case_5():
pytorch_code = textwrap.dedent(
"""
import torch
input = torch.arange(6.0).reshape((2, 3))
src = torch.ones((2,))
result = input.diagonal_scatter(src, 0, -2, 1)
"""
)
obj.run(pytorch_code, ["result"])
12 changes: 12 additions & 0 deletions tests/test_Tensor_hypot.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,15 @@ def test_case_4():
"""
)
obj.run(pytorch_code, ["result"])


def test_case_5():
pytorch_code = textwrap.dedent(
"""
import torch
a = torch.tensor([1., 2, 3])
b = torch.tensor([4., 5, 6])
result = a.hypot(b+1)
"""
)
obj.run(pytorch_code, ["result"])
78 changes: 78 additions & 0 deletions tests/test_Tensor_hypot_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import textwrap

from apibase import APIBase

obj = APIBase("torch.Tensor.hypot_")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
a = torch.tensor([1., 2, 3])
b = torch.tensor([4., 5, 6])
result = a.hypot_(b)
"""
)
obj.run(pytorch_code, ["result", "a"])


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
a = torch.tensor([1., 2, 3])
b = torch.tensor([4., 5, 6])
result = a.hypot_(other=b)
"""
)
obj.run(pytorch_code, ["result", "a"])


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
a = torch.tensor([-1., 2, 3])
b = torch.tensor([4., 5, 6])
result = a.hypot_(other=b)
"""
)
obj.run(pytorch_code, ["result", "a"])


def test_case_4():
pytorch_code = textwrap.dedent(
"""
import torch
a = torch.tensor([1., 2, 3])
b = torch.tensor([4., 5, 6])
result = a.hypot_(other=b+1)
"""
)
obj.run(pytorch_code, ["result", "a"])


def test_case_5():
pytorch_code = textwrap.dedent(
"""
import torch
a = torch.tensor([1., 2, 3])
b = torch.tensor([4., 5, 6])
result = a.hypot_(b+1)
"""
)
obj.run(pytorch_code, ["result", "a"])
82 changes: 82 additions & 0 deletions tests/test_Tensor_index_fill.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.Tensor.index_fill")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.eye(2, 4)
indices = torch.tensor([0, 1])
value = -1
result = x.index_fill(0, indices, value)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
indices = torch.tensor([0, 1])
value = -1
result = torch.eye(3, 4).index_fill(1, indices, value)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
indices = torch.tensor([0, 1])
dim = 0
value = -1
result = torch.eye(3, 4).index_fill(index=indices, dim=dim, value=value)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_4():
pytorch_code = textwrap.dedent(
"""
import torch
indices = torch.tensor([0, 3])
dim = 0
value = -1
result = torch.eye(6, 4).index_fill(dim=dim, index=indices, value=value)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_5():
pytorch_code = textwrap.dedent(
"""
import torch
indices = torch.tensor([0, 3])
value = -1
result = torch.eye(3, 4).index_fill(1, indices, value)
"""
)
obj.run(pytorch_code, ["result"])
Loading