-
Notifications
You must be signed in to change notification settings - Fork 55
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon 5th No.47】API转换 103-124 #346
Changes from all commits
31b37e0
8889264
16c73a4
83a94a9
f4b789e
8252a5a
f59a4d1
c1b69fb
ab3281f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -169,6 +169,20 @@ def get_paddle_class_nodes(self, func, args, kwargs): | |
return "delete" | ||
|
||
|
||
class AtleastMatcher(BaseMatcher): | ||
def get_paddle_nodes(self, args, kwargs): | ||
new_args = self.parse_args(args) | ||
if new_args[0][0] == "(" and new_args[0][-1] == ")": | ||
new_args[0] = new_args[0][1:-1] | ||
if new_args[0][0] == "[" and new_args[0][-1] == "]": | ||
new_args[0] = new_args[0][1:-1] | ||
new_kwargs = self.parse_kwargs(kwargs) | ||
code = "{}({})".format( | ||
self.get_paddle_api(), self.args_and_kwargs_to_str(new_args, new_kwargs) | ||
) | ||
return ast.parse(code).body | ||
|
||
|
||
class UnchangeMatcher(BaseMatcher): | ||
def get_paddle_class_attribute_nodes(self, node): | ||
return "unchange" | ||
|
@@ -1549,7 +1563,11 @@ def generate_code(self, kwargs): | |
if len(kwargs) == 0: | ||
code = f"str({self.paddleClass}.dtype)" | ||
else: | ||
code = f"{self.paddleClass}.astype({kwargs['dtype']})" | ||
# For torch.nn.Module.type, torch.nn.Module.type use torch.Tensor.type | ||
if "dst_type" in kwargs: | ||
code = f"{self.paddleClass}.astype({kwargs['dst_type']})" | ||
else: | ||
code = f"{self.paddleClass}.astype({kwargs['dtype']})" | ||
return code | ||
|
||
|
||
|
@@ -2111,61 +2129,21 @@ def get_paddle_nodes(self, args, kwargs): | |
|
||
|
||
class TensorToMatcher(BaseMatcher): | ||
def generate_aux_code(self): | ||
CODE_TEMPLATE = textwrap.dedent( | ||
""" | ||
def to(self, *args, **kwargs): | ||
args_list = ["x", "y", "non_blocking", "copy", "memory_format"] | ||
new_kwargs = {} | ||
for i, node in enumerate(args): | ||
k = args_list[i] | ||
new_kwargs[k] = node | ||
for node in kwargs: | ||
v = kwargs[node] | ||
new_kwargs[node] = v | ||
kwargs = new_kwargs | ||
if not kwargs: | ||
return self | ||
elif "tensor" in kwargs: | ||
return paddle.cast(self, "{}.dtype".format(kwargs["tensor"])) | ||
elif "dtype" in kwargs: | ||
return paddle.cast(self, "{}".format(kwargs["dtype"])) | ||
elif "device" in kwargs and "dtype" not in kwargs: | ||
return self | ||
elif kwargs: | ||
if "y" not in kwargs and "x" in kwargs: | ||
if isinstance(kwargs["x"], paddle.dtype): | ||
dtype = kwargs["x"] | ||
elif isinstance(kwargs["x"], str) and kwargs["x"] not in ['cpu', 'cuda', 'ipu', 'xpu']: | ||
dtype = kwargs["x"] | ||
elif isinstance(kwargs["x"], paddle.Tensor): | ||
dtype = kwargs["x"].dtype | ||
else: | ||
dtype = self.dtype | ||
return paddle.cast(self, dtype) | ||
|
||
elif "y" in kwargs and "x" in kwargs: | ||
if isinstance(kwargs["x"], paddle.dtype): | ||
dtype = kwargs["x"] | ||
elif isinstance(kwargs["x"], str): | ||
if x not in ['cpu', 'cuda', 'ipu', 'xpu']: | ||
dtype = kwargs["x"] | ||
else: | ||
dtype = kwargs["y"] if isinstance(kwargs["y"], str) else self.dtype | ||
else: | ||
dtype = kwargs["x"] | ||
return paddle.cast(self, dtype) | ||
else: | ||
return self | ||
|
||
setattr(paddle.Tensor, 'to', to) | ||
""" | ||
def get_paddle_nodes(self, args, kwargs): | ||
new_args = self.parse_args(args) | ||
new_kwargs = self.parse_kwargs(kwargs) | ||
if new_kwargs is None: | ||
return None | ||
if "copy" in new_kwargs: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个是不是考虑json配置下,直接kwargs_change成空字符串 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. kwargs_change 使用GenericMatcher,GenericMatcher 中 dtype, device转换不同不适合使用,然后只有参数 copy ,不需要使用 kwargs_change |
||
new_kwargs.pop("copy") | ||
if "memory_format" in new_kwargs: | ||
new_kwargs.pop("memory_format") | ||
if "non_blocking" in new_kwargs: | ||
new_kwargs["blocking"] = "not " + new_kwargs.pop("non_blocking").strip("()") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 有个 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
code = "{}.to({})".format( | ||
self.paddleClass, self.args_and_kwargs_to_str(new_args, new_kwargs) | ||
) | ||
return CODE_TEMPLATE | ||
|
||
def get_paddle_class_nodes(self, func, args, kwargs): | ||
self.write_aux_code() | ||
return "unchange" | ||
return ast.parse(code).body | ||
|
||
|
||
class TensorRequiresGradMatcher(BaseMatcher): | ||
|
@@ -2964,24 +2942,6 @@ def get_paddle_nodes(self, args, kwargs): | |
return ast.parse(code).body | ||
|
||
|
||
class HypotMatcher(BaseMatcher): | ||
def generate_code(self, kwargs): | ||
if "input" not in kwargs: | ||
kwargs["input"] = self.paddleClass | ||
|
||
API_TEMPLATE = textwrap.dedent( | ||
""" | ||
paddle.pow({}**2 + {}**2, 1/2) | ||
""" | ||
) | ||
code = API_TEMPLATE.format(kwargs["input"], kwargs["other"]) | ||
|
||
if "out" in kwargs and kwargs["out"] != "None": | ||
code = "paddle.assign({}, output={})".format(code, kwargs["out"]) | ||
|
||
return code | ||
|
||
|
||
class TensorViewMatcher(BaseMatcher): | ||
def generate_aux_code(self): | ||
CODE_TEMPLATE = textwrap.dedent( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import textwrap | ||
|
||
from apibase import APIBase | ||
|
||
obj = APIBase("torch.Tensor.hypot_") | ||
|
||
|
||
def test_case_1(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
a = torch.tensor([1., 2, 3]) | ||
b = torch.tensor([4., 5, 6]) | ||
result = a.hypot_(b) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result", "a"]) | ||
|
||
|
||
def test_case_2(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
a = torch.tensor([1., 2, 3]) | ||
b = torch.tensor([4., 5, 6]) | ||
result = a.hypot_(other=b) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result", "a"]) | ||
|
||
|
||
def test_case_3(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
a = torch.tensor([-1., 2, 3]) | ||
b = torch.tensor([4., 5, 6]) | ||
result = a.hypot_(other=b) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result", "a"]) | ||
|
||
|
||
def test_case_4(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
a = torch.tensor([1., 2, 3]) | ||
b = torch.tensor([4., 5, 6]) | ||
result = a.hypot_(other=b+1) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result", "a"]) | ||
|
||
|
||
def test_case_5(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
a = torch.tensor([1., 2, 3]) | ||
b = torch.tensor([4., 5, 6]) | ||
result = a.hypot_(b+1) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result", "a"]) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import textwrap | ||
|
||
from apibase import APIBase | ||
|
||
obj = APIBase("torch.Tensor.index_fill") | ||
|
||
|
||
def test_case_1(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
x = torch.eye(2, 4) | ||
indices = torch.tensor([0, 1]) | ||
value = -1 | ||
result = x.index_fill(0, indices, value) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result"]) | ||
|
||
|
||
def test_case_2(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
indices = torch.tensor([0, 1]) | ||
value = -1 | ||
result = torch.eye(3, 4).index_fill(1, indices, value) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result"]) | ||
|
||
|
||
def test_case_3(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
indices = torch.tensor([0, 1]) | ||
dim = 0 | ||
value = -1 | ||
result = torch.eye(3, 4).index_fill(index=indices, dim=dim, value=value) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result"]) | ||
|
||
|
||
def test_case_4(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
indices = torch.tensor([0, 3]) | ||
dim = 0 | ||
value = -1 | ||
result = torch.eye(6, 4).index_fill(dim=dim, index=indices, value=value) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result"]) | ||
|
||
|
||
def test_case_5(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
indices = torch.tensor([0, 3]) | ||
value = -1 | ||
result = torch.eye(3, 4).index_fill(1, indices, value) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result"]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
什么时候会
return None
,return None时一般是不支持,应该继续return NoneThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
API有bug时返回None,已修改return None