Skip to content

Commit

Permalink
Add api test case (#154)
Browse files Browse the repository at this point in the history
Add api matcher (#89)

* add fix api

* fix bug

* fix bug
  • Loading branch information
ROckDog22 authored Jul 5, 2023
1 parent 182bd89 commit 96300cb
Show file tree
Hide file tree
Showing 62 changed files with 2,766 additions and 59 deletions.
51 changes: 29 additions & 22 deletions paconvert/api_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,8 @@
},
"torch.Tensor.arcsin_": {},
"torch.Tensor.arcsinh": {
"Matcher": "TensorUnchangeMatcher"
"Matcher": "TensorFunc2PaddleFunc",
"paddle_api": "paddle.asinh"
},
"torch.Tensor.arcsinh_": {},
"torch.Tensor.arctan": {
Expand All @@ -226,7 +227,8 @@
"torch.Tensor.arctan2_": {},
"torch.Tensor.arctan_": {},
"torch.Tensor.arctanh": {
"Matcher": "TensorUnchangeMatcher"
"Matcher": "TensorFunc2PaddleFunc",
"paddle_api": "paddle.atanh"
},
"torch.Tensor.arctanh_": {},
"torch.Tensor.argmax": {
Expand Down Expand Up @@ -1755,8 +1757,8 @@
"torch.Tensor.orgqr": {},
"torch.Tensor.ormqr": {},
"torch.Tensor.outer": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.Tensor.outer",
"Matcher": "TensorFunc2PaddleFunc",
"paddle_api": "paddle.outer",
"args_list": [
"vec2"
],
Expand Down Expand Up @@ -1820,8 +1822,8 @@
"paddle_api": "paddle.Tensor.flatten"
},
"torch.Tensor.reciprocal": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.Tensor.reciprocal"
"Matcher": "TensorFunc2PaddleFunc",
"paddle_api": "paddle.reciprocal"
},
"torch.Tensor.reciprocal_": {
"Matcher": "GenericMatcher",
Expand All @@ -1837,8 +1839,8 @@
]
},
"torch.Tensor.remainder": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.Tensor.remainder",
"Matcher": "TensorFunc2PaddleFunc",
"paddle_api": "paddle.remainder",
"args_list": [
"divisor"
],
Expand All @@ -1847,8 +1849,8 @@
}
},
"torch.Tensor.remainder_": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.Tensor.remainder_",
"Matcher": "TensorFunc2PaddleFunc",
"paddle_api": "paddle.remainder_",
"args_list": [
"divisor"
],
Expand Down Expand Up @@ -1909,29 +1911,29 @@
"torch.Tensor.resolve_neg": {},
"torch.Tensor.retain_grad": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.Tensor.retain_grad"
"paddle_api": "paddle.Tensor.retain_grads"
},
"torch.Tensor.retains_grad": {},
"torch.Tensor.roll": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.Tensor.roll",
"args_list": [
"shifts",
"dim"
"dims"
],
"kwargs_change": {
"dim": "axis"
"dims": "axis"
}
},
"torch.Tensor.rot90": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.Tensor.rot90",
"Matcher": "TensorFunc2PaddleFunc",
"paddle_api": "paddle.rot90",
"args_list": [
"k",
"dims"
],
"kwargs_change": {
"dims": "axis"
"dims": "axes"
}
},
"torch.Tensor.round": {
Expand Down Expand Up @@ -2008,7 +2010,7 @@
},
"torch.Tensor.slice_scatter": {},
"torch.Tensor.slogdet": {
"Matcher": "TensorSLogDetMatcher",
"Matcher": "SLogDetMatcher",
"paddle_api": "paddle.linalg.slogdet",
"args_list": [
"out"
Expand All @@ -2022,15 +2024,19 @@
]
},
"torch.Tensor.sort": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.Tensor.sort",
"Matcher": "TensorFunc2PaddleFunc",
"paddle_api": "paddle.sort",
"args_list": [
"dim",
"descending"
"descending",
"stable"
],
"kwargs_change": {
"dim": "axis"
}
},
"unsupport_args": [
"stable"
]
},
"torch.Tensor.sparse_dim": {},
"torch.Tensor.sparse_mask": {},
Expand Down Expand Up @@ -2169,7 +2175,8 @@
}
},
"torch.Tensor.tan": {
"Matcher": "TensorUnchangeMatcher"
"Matcher": "TensorFunc2PaddleFunc",
"paddle_api": "paddle.tan"
},
"torch.Tensor.tan_": {},
"torch.Tensor.tanh": {
Expand Down
38 changes: 11 additions & 27 deletions paconvert/api_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2475,6 +2475,8 @@ def generate_code(self, kwargs):

class NumelMatcher(BaseMatcher):
def generate_code(self, kwargs):
if "input" not in kwargs:
kwargs["input"] = self.paddleClass
return "{}.size".format(kwargs["input"])


Expand Down Expand Up @@ -2948,7 +2950,13 @@ def generate_code(self, kwargs):
out_v = kwargs.pop("out") if "out" in kwargs else None

if "input" in kwargs:
kwargs["A"] = kwargs.pop("input")
x_v = kwargs.pop("input")

elif "A" in kwargs:
x_v = kwargs.pop("A")

else:
x_v = self.paddleClass

if out_v:
API_TEMPLATE = textwrap.dedent(
Expand All @@ -2957,15 +2965,15 @@ def generate_code(self, kwargs):
paddle.assign(res[0], {}[0]), paddle.assign(res[1], {}[1])
"""
)
code = API_TEMPLATE.format(kwargs["A"], out_v, out_v)
code = API_TEMPLATE.format(x_v, out_v, out_v)
else:
API_TEMPLATE = textwrap.dedent(
"""
res = paddle.linalg.slogdet({})
res[0], res[1]
"""
)
code = API_TEMPLATE.format(kwargs["A"])
code = API_TEMPLATE.format(x_v)

return code

Expand Down Expand Up @@ -3769,27 +3777,3 @@ def generate_code(self, kwargs):
code = "{}".format(self.get_paddle_api())

return code


class TensorSLogDetMatcher(BaseMatcher):
def generate_code(self, kwargs):
out_v = kwargs.pop("out") if "out" in kwargs else None

if out_v:
API_TEMPLATE = textwrap.dedent(
"""
res = paddle.linalg.slogdet({})
paddle.assign(res[0], {}[0]), paddle.assign(res[1], {}[1])
"""
)
code = API_TEMPLATE.format(self.paddleClass, out_v, out_v)
else:
API_TEMPLATE = textwrap.dedent(
"""
res = paddle.linalg.slogdet({})
res[0], res[1]
"""
)
code = API_TEMPLATE.format(self.paddleClass)

return code
111 changes: 111 additions & 0 deletions tests/test_Module_named_parameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.nn.Moudle.named_parameters")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
result = []
class TestForHook(nn.Module):
def __init__(self):
super().__init__()
self.linear_1 = nn.Linear(in_features=2, out_features=2)
def forward(self, x):
x1 = self.linear_1(x)
return x, x, x1
a = TestForHook()
for a,b in a.named_parameters(prefix="wfs"):
result.append(b)
result = result[0]
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
result = []
class TestForHook(nn.Module):
def __init__(self):
super().__init__()
self.linear_1 = nn.Linear(in_features=2, out_features=2)
def forward(self, x):
x1 = self.linear_1(x)
return x, x, x1
a = TestForHook()
for a,b in a.named_parameters(prefix="wfs", recurse=True):
result.append(b)
result = result[0]
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
result = []
class TestForHook(nn.Module):
def __init__(self):
super().__init__()
self.linear_1 = nn.Linear(in_features=2, out_features=2)
def forward(self, x):
x1 = self.linear_1(x)
return x, x, x1
a = TestForHook()
for a,b in a.named_parameters(prefix="wfs", recurse=True, remove_duplicate = True):
result.append(b)
result = result[0]
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


def test_case_4():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
result = []
class TestForHook(nn.Module):
def __init__(self):
super().__init__()
self.linear_1 = nn.Linear(in_features=2, out_features=2)
def forward(self, x):
x1 = self.linear_1(x)
return x, x, x1
a = TestForHook()
for a,b in a.named_parameters(remove_duplicate = True):
result.append(b)
result = result[0]
"""
)
obj.run(pytorch_code, ["result"], check_value=False)
49 changes: 49 additions & 0 deletions tests/test_Module_register_buffer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.nn.Moudle.register_buffer")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([1., 2., 3.])
module1 = torch.nn.Module()
module1.register_buffer('buffer', x)
module2 = torch.nn.Module()
module2.register_module('submodule', module1)
result = module2.submodule.buffer
"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([1., 2., 3.])
module1 = torch.nn.Module()
module1.register_buffer('buffer', x)
module2 = torch.nn.Module()
module2.register_module(name='submodule', module=module1)
result = module2.submodule.buffer
"""
)
obj.run(pytorch_code, ["result"])
Loading

0 comments on commit 96300cb

Please sign in to comment.