Conversion rules No.16/17/18 (#142)

* Rule conversion No.16/17/18.

* Add conversion tests for cummin and searchsorted.

* fix error

* fix error at SearchsortedMatcher

* fix code style error
Li-fAngyU authored Jul 5, 2023
1 parent 2b1000d commit 748524a
Showing 5 changed files with 332 additions and 31 deletions.
38 changes: 37 additions & 1 deletion paconvert/api_mapping.json
@@ -3398,6 +3398,19 @@
"device"
]
},
"torch.cummin": {
"Matcher": "TupleAssignMatcher",
"paddle_api": "paddle.cummin",
"args_list": [
"input",
"dim",
"out"
],
"kwargs_change": {
"input": "x",
"dim": "axis"
}
},
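For orientation, a minimal sketch of the conversion this mapping is meant to drive (illustrative only: the tensor name x is a placeholder, and the actual rewrite is produced by TupleAssignMatcher together with the kwargs_change table above):

# PyTorch source
values, indices = torch.cummin(x, dim=1)
# Expected Paddle equivalent: input -> x, dim -> axis
values, indices = paddle.cummin(x, axis=1)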
"torch.cumprod": {
"Matcher": "CumprodMatcher",
"paddle_api": "paddle.cumprod",
@@ -8418,6 +8431,18 @@
"dtype": "paddle.float32"
}
},
"torch.searchsorted": {
"Matcher": "SearchsortedMatcher",
"args_list": [
"sorted_sequence",
"values",
"out_int32",
"right",
"side",
"out",
"sorter"
]
},
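As a rough illustration of the intended mapping (variable names are placeholders; the handling of side, sorter and out is done by the SearchsortedMatcher added in api_matcher.py below):

# PyTorch source
result = torch.searchsorted(sorted_seq, values, right=True)
# Expected Paddle equivalent
result = paddle.searchsorted(sorted_sequence=sorted_seq, values=values, right=True)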
"torch.seed": {
"Matcher": "SeedMatcher"
},
@@ -9214,7 +9239,18 @@
"tensor": "x"
}
},
"torch.vander": {},
"torch.vander": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.vander",
"args_list": [
"x",
"N",
"increasing"
],
"kwargs_change": {
"N": "n"
}
},
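A hedged sketch of the expected rewrite under GenericMatcher with the N -> n rename (x is a placeholder tensor):

# PyTorch source
result = torch.vander(x, N=3, increasing=True)
# Expected Paddle equivalent
result = paddle.vander(x, n=3, increasing=True)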
"torch.var": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.var",
24 changes: 24 additions & 0 deletions paconvert/api_matcher.py
@@ -2867,6 +2867,30 @@ def generate_code(self, kwargs):
return code


class SearchsortedMatcher(BaseMatcher):
    def generate_code(self, kwargs):
        # Paddle has no 'side' argument; map it to the boolean 'right' flag
        # (the generated expression compares the original string against 'right').
        if "side" in kwargs:
            kwargs["right"] = kwargs.pop("side").strip("\n") + "== 'right'"

        # Paddle has no 'sorter' argument; reorder the sequence with take_along_axis instead.
        if "sorter" in kwargs and kwargs["sorter"] is not None:
            kwargs[
                "sorted_sequence"
            ] += ".take_along_axis(axis=-1, indices = {})".format(
                kwargs.pop("sorter").strip("\n")
            )

        code = "paddle.searchsorted({})".format(self.kwargs_to_str(kwargs))

        # 'out' is supported by wrapping the call in paddle.assign.
        if "out" in kwargs and kwargs["out"] is not None:
            out_v = kwargs.pop("out").strip("\n")
            code = "paddle.assign(paddle.searchsorted({}), output={})".format(
                self.kwargs_to_str(kwargs), out_v
            )

        return code
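As a rough illustration (not verbatim matcher output; spacing and argument order may differ), a call that uses side, sorter and out would be rewritten along these lines, with placeholder names x, values, s and res:

# PyTorch source
result = torch.searchsorted(x, values, side='right', sorter=s, out=res)
# Generated Paddle code: sorter is folded into take_along_axis, side becomes the
# boolean right flag ('right' == 'right' evaluates to True), and out is handled via paddle.assign
result = paddle.assign(paddle.searchsorted(sorted_sequence=x.take_along_axis(axis=-1, indices=s), values=values, right='right' == 'right'), output=res)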


class SincMatcher(BaseMatcher):
def generate_code(self, kwargs):
if "input" not in kwargs:
118 changes: 118 additions & 0 deletions tests/test_cummin.py
@@ -0,0 +1,118 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.cummin")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([[1.0, 1.0, 1.0],
[2.0, 2.0, 2.0],
[3.0, 3.0, 3.0]])
result = torch.cummin(x, 0)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([[1.0, 1.0, 1.0],
[2.0, 2.0, 2.0],
[3.0, 3.0, 3.0]])
result = torch.cummin(x, dim=1)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([[1.0, 1.0, 1.0],
[2.0, 2.0, 2.0],
[3.0, 3.0, 3.0]])
result = torch.cummin(input=x, dim=1)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_4():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([[1.0, 1.0, 1.0],
[2.0, 2.0, 2.0],
[3.0, 3.0, 3.0]])
values = torch.tensor([[1.0, 1.0, 1.0],
[2.0, 2.0, 2.0],
[3.0, 3.0, 3.0]]).float()
indices = torch.tensor([[1, 1, 1],
[2, 2, 2],
[3, 3, 3]])
out = (values, indices)
result = torch.cummin(x, 0, out=(values, indices))
"""
)
obj.run(pytorch_code, ["result", "out"])


def test_case_5():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([[1.0, 1.0, 1.0],
[2.0, 2.0, 2.0],
[3.0, 3.0, 3.0]])
values = torch.tensor([[1.0, 1.0, 1.0],
[2.0, 2.0, 2.0],
[3.0, 3.0, 3.0]]).float()
indices = torch.tensor([[1, 1, 1],
[2, 2, 2],
[3, 3, 3]])
out = (values, indices)
result = torch.cummin(x, dim = 0, out=(values, indices))
"""
)
obj.run(pytorch_code, ["result", "out"])


def test_case_6():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([[1.0, 1.0, 1.0],
[2.0, 2.0, 2.0],
[3.0, 3.0, 3.0]])
values = torch.tensor([[1.0, 1.0, 1.0],
[2.0, 2.0, 2.0],
[3.0, 3.0, 3.0]]).float()
indices = torch.tensor([[1, 1, 1],
[2, 2, 2],
[3, 3, 3]])
out = (values, indices)
result = torch.cummin(input = x, dim =0, out=(values, indices))
"""
)
obj.run(pytorch_code, ["result", "out"])
137 changes: 137 additions & 0 deletions tests/test_searchsorted.py
@@ -0,0 +1,137 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.searchsorted")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([[ 1, 3, 5, 7, 9],
[ 2, 4, 6, 8, 10]])
values = torch.tensor([[3, 6, 9],
[3, 6, 9]])
result = torch.searchsorted(x, values)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([[ 1, 3, 5, 7, 9],
[ 2, 4, 6, 8, 10]])
values = torch.tensor([[3, 6, 9],
[3, 6, 9]])
result = torch.searchsorted(x, values, out_int32 = True)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([[ 1, 3, 5, 7, 9],
[ 2, 4, 6, 8, 10]])
values = torch.tensor([[3, 6, 9],
[3, 6, 9]])
result = torch.searchsorted(x, values, right = True)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_4():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([[ 1, 3, 5, 7, 9],
[ 2, 4, 6, 8, 10]])
values = torch.tensor([[3, 6, 9],
[3, 6, 9]])
result = torch.searchsorted(x, values, side = 'right')
"""
)
obj.run(pytorch_code, ["result"])


def test_case_5():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([[ 1, 3, 5, 7, 9],
[ 2, 4, 6, 8, 10]])
values = torch.tensor([[3, 6, 9],
[3, 6, 9]])
out = torch.tensor([[3, 6, 9],
[3, 6, 9]])
result = torch.searchsorted(x, values, out = out)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_6():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([[ 1, 3, 9, 7, 5],
[ 2, 4, 6, 8, 10]])
values = torch.tensor([[3, 6, 9],
[3, 6, 9]])
sorter = torch.argsort(x)
result = torch.searchsorted(x, values, sorter = sorter)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_7():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([[ 1, 3, 9, 7, 5],
[ 2, 4, 6, 8, 10]])
values = torch.tensor([[3, 6, 9],
[3, 6, 9]])
out = torch.tensor([[3, 6, 9],
[3, 6, 9]])
sorter = torch.argsort(x)
result = torch.searchsorted(x, values, right = True, side = 'right', out = out, sorter = sorter)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_8():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([[ 1, 3, 5, 7, 9],
[ 2, 4, 6, 8, 10]])
values = torch.tensor([[3, 6, 9],
[3, 6, 9]])
result = torch.searchsorted(x, values, right = False, side = 'right')
"""
)
obj.run(pytorch_code, ["result"])
