Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

转换规则No64 torch.nn.functional.max_pool1d #120

Merged
merged 15 commits into from
Aug 28, 2023
Merged
23 changes: 23 additions & 0 deletions paconvert/api_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -8732,5 +8732,28 @@
"kwargs_change": {
"eps": "epsilon"
}
},
"torch.nn.functional.max_pool1d": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.functional.max_pool1d",
"args_list": [
"input",
"kernel_size",
"stride",
"padding",
"dilation",
"ceil_mode",
"return_indices"
],
"unsupport_args": [
"dilation",
"return_indices"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个return_indices应该是支持转化的,paddle和torch返回类型不同,int32和int64,属于风格差异,通过转化类型可以在单测里面进行比较

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

很抱歉隔了很久才开始处理这个问题,麻烦帮忙check一下增加的单元测试中的case5和case6,都使用了return index,但是case6的return shape是不同的

],
"kwargs_change": {
"input": "x"
},
"paddle_default_kwargs": {
"return_mask": "False"
}
}
}
92 changes: 92 additions & 0 deletions tests/test_nn_functional_max_pool1d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.nn.functional.max_pool1d")


Liyulingyue marked this conversation as resolved.
Show resolved Hide resolved
def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn.functional as F
input = torch.tensor([[[ 1.1524, 0.4714, 0.2857],
[-1.2533, -0.9829, -1.0981],
[ 0.1507, -1.1431, -2.0361]]])
result = F.max_pool1d(input , 3)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn.functional as F
input = torch.tensor([[[ 1.1524, 0.4714, 0.2857],
[-1.2533, -0.9829, -1.0981],
[ 0.1507, -1.1431, -2.0361]]])
result = F.max_pool1d(input , 3, stride=2)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn.functional as F
input = torch.tensor([[[ 1.1524, 0.4714, 0.2857, 0.4586, 0.9876, 0.5487],
[-1.2533, -0.9829, -1.0981, 0.7655, 0.8541, 0.9873],
[ 0.1507, -1.1431, -2.0361, 0.2344, 0.5675, 0.1546]]])
result = F.max_pool1d(input , 5, stride=2, padding=2)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_4():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn.functional as F
input = torch.tensor([[[ 1.1524, 0.4714, 0.2857, 0.4586, 0.9876, 0.5487],
[-1.2533, -0.9829, -1.0981, 0.7655, 0.8541, 0.9873],
[ 0.1507, -1.1431, -2.0361, 0.2344, 0.5675, 0.1546]]])
result = F.max_pool1d(input , 5, stride=2, padding=2, ceil_mode=True)
"""
)
obj.run(pytorch_code, ["result"])


'''
# if enable the return_indices, the results of torch and paddle are different
def test_case_5():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn.functional as F
input = torch.tensor([[[ 1.1524, 0.4714, 0.2857, 0.4586, 0.9876, 0.5487],
[-1.2533, -0.9829, -1.0981, 0.7655, 0.8541, 0.9873],
[ 0.1507, -1.1431, -2.0361, 0.2344, 0.5675, 0.1546]]])
result = F.max_pool1d(input , 5, stride=2, padding=2, ceil_mode=True, return_indices=True)
"""
)
obj.run(pytorch_code, ["result"])
'''