-
Notifications
You must be signed in to change notification settings - Fork 55
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
转换规则No64 torch.nn.functional.max_pool1d #120
Merged
Merged
Changes from 7 commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
41682d7
torch.nn.functional.max_pool1d
Liyulingyue 660ef1a
torch.nn.functional.max_pool1d
Liyulingyue 143ed53
torch.nn.functional.max_pool1d
Liyulingyue 637c070
torch.nn.functional.max_pool1d
Liyulingyue d6dc8c7
torch.nn.functional.max_pool1d
Liyulingyue 05cc1f8
torch.nn.functional.max_pool1d
Liyulingyue 0ca50e8
torch.nn.functional.max_pool1d
Liyulingyue c1b849a
Merge branch 'master' into case64
Liyulingyue 91af61b
add json
Liyulingyue 21e0a11
Update test_nn_functional_max_pool1d.py
Liyulingyue 13007b0
Update test_nn_functional_max_pool1d.py
Liyulingyue c4e7c4d
Update test_nn_functional_max_pool1d.py
Liyulingyue 71556d7
Apply suggestions from code review
Liyulingyue 74801d4
Apply suggestions from code review
Liyulingyue bc91f82
Merge branch 'PaddlePaddle:master' into case64
Liyulingyue File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import textwrap | ||
|
||
from apibase import APIBase | ||
|
||
obj = APIBase("torch.nn.functional.max_pool1d") | ||
|
||
|
||
Liyulingyue marked this conversation as resolved.
Show resolved
Hide resolved
|
||
def test_case_1(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
import torch.nn.functional as F | ||
input = torch.tensor([[[ 1.1524, 0.4714, 0.2857], | ||
[-1.2533, -0.9829, -1.0981], | ||
[ 0.1507, -1.1431, -2.0361]]]) | ||
result = F.max_pool1d(input , 3) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result"]) | ||
|
||
|
||
def test_case_2(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
import torch.nn.functional as F | ||
input = torch.tensor([[[ 1.1524, 0.4714, 0.2857], | ||
[-1.2533, -0.9829, -1.0981], | ||
[ 0.1507, -1.1431, -2.0361]]]) | ||
result = F.max_pool1d(input , 3, stride=2) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result"]) | ||
|
||
|
||
def test_case_3(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
import torch.nn.functional as F | ||
input = torch.tensor([[[ 1.1524, 0.4714, 0.2857, 0.4586, 0.9876, 0.5487], | ||
[-1.2533, -0.9829, -1.0981, 0.7655, 0.8541, 0.9873], | ||
[ 0.1507, -1.1431, -2.0361, 0.2344, 0.5675, 0.1546]]]) | ||
result = F.max_pool1d(input , 5, stride=2, padding=2) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result"]) | ||
|
||
|
||
def test_case_4(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
import torch.nn.functional as F | ||
input = torch.tensor([[[ 1.1524, 0.4714, 0.2857, 0.4586, 0.9876, 0.5487], | ||
[-1.2533, -0.9829, -1.0981, 0.7655, 0.8541, 0.9873], | ||
[ 0.1507, -1.1431, -2.0361, 0.2344, 0.5675, 0.1546]]]) | ||
result = F.max_pool1d(input , 5, stride=2, padding=2, ceil_mode=True) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result"]) | ||
|
||
|
||
''' | ||
# if enable the return_indices, the results of torch and paddle are different | ||
def test_case_5(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
import torch.nn.functional as F | ||
input = torch.tensor([[[ 1.1524, 0.4714, 0.2857, 0.4586, 0.9876, 0.5487], | ||
[-1.2533, -0.9829, -1.0981, 0.7655, 0.8541, 0.9873], | ||
[ 0.1507, -1.1431, -2.0361, 0.2344, 0.5675, 0.1546]]]) | ||
result = F.max_pool1d(input , 5, stride=2, padding=2, ceil_mode=True, return_indices=True) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result"]) | ||
''' |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个return_indices应该是支持转化的,paddle和torch返回类型不同,int32和int64,属于风格差异,通过转化类型可以在单测里面进行比较
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
很抱歉隔了很久才开始处理这个问题,麻烦帮忙check一下增加的单元测试中的case5和case6,都使用了return index,但是case6的return shape是不同的