diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json index 2d74e0590..b52556792 100644 --- a/paconvert/api_mapping.json +++ b/paconvert/api_mapping.json @@ -3398,6 +3398,19 @@ "device" ] }, + "torch.cummin": { + "Matcher": "TupleAssignMatcher", + "paddle_api": "paddle.cummin", + "args_list": [ + "input", + "dim", + "out" + ], + "kwargs_change": { + "input": "x", + "dim": "axis" + } + }, "torch.cumprod": { "Matcher": "CumprodMatcher", "paddle_api": "paddle.cumprod", @@ -8418,6 +8431,18 @@ "dtype": "paddle.float32" } }, + "torch.searchsorted": { + "Matcher": "SearchsortedMatcher", + "args_list": [ + "sorted_sequence", + "values", + "out_int32", + "right", + "side", + "out", + "sorter" + ] + }, "torch.seed": { "Matcher": "SeedMatcher" }, @@ -9214,7 +9239,18 @@ "tensor": "x" } }, - "torch.vander": {}, + "torch.vander": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.vander", + "args_list": [ + "x", + "N", + "increasing" + ], + "kwargs_change": { + "N": "n" + } + }, "torch.var": { "Matcher": "GenericMatcher", "paddle_api": "paddle.var", diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py index 77babe4d0..fe390c319 100644 --- a/paconvert/api_matcher.py +++ b/paconvert/api_matcher.py @@ -2867,6 +2867,30 @@ def generate_code(self, kwargs): return code +class SearchsortedMatcher(BaseMatcher): + def generate_code(self, kwargs): + + if "side" in kwargs: + kwargs["right"] = kwargs.pop("side").strip("\n") + "== 'right'" + + if "sorter" in kwargs and kwargs["sorter"] is not None: + kwargs[ + "sorted_sequence" + ] += ".take_along_axis(axis=-1, indices = {})".format( + kwargs.pop("sorter").strip("\n") + ) + + code = "paddle.searchsorted({})".format(self.kwargs_to_str(kwargs)) + + if "out" in kwargs and kwargs["out"] is not None: + out_v = kwargs.pop("out").strip("\n") + code = "paddle.assign(paddle.searchsorted({}), output={})".format( + self.kwargs_to_str(kwargs), out_v + ) + + return code + + class SincMatcher(BaseMatcher): def generate_code(self, kwargs): if "input" not in kwargs: diff --git a/tests/test_cummin.py b/tests/test_cummin.py new file mode 100644 index 000000000..4e92c0be8 --- /dev/null +++ b/tests/test_cummin.py @@ -0,0 +1,118 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.cummin") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1.0, 1.0, 1.0], + [2.0, 2.0, 2.0], + [3.0, 3.0, 3.0]]) + result = torch.cummin(x, 0) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1.0, 1.0, 1.0], + [2.0, 2.0, 2.0], + [3.0, 3.0, 3.0]]) + result = torch.cummin(x, dim=1) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1.0, 1.0, 1.0], + [2.0, 2.0, 2.0], + [3.0, 3.0, 3.0]]) + result = torch.cummin(input=x, dim=1) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1.0, 1.0, 1.0], + [2.0, 2.0, 2.0], + [3.0, 3.0, 3.0]]) + values = torch.tensor([[1.0, 1.0, 1.0], + [2.0, 2.0, 2.0], + [3.0, 3.0, 3.0]]).float() + indices = torch.tensor([[1, 1, 1], + [2, 2, 2], + [3, 3, 3]]) + out = (values, indices) + result = torch.cummin(x, 0, out=(values, indices)) + """ + ) + obj.run(pytorch_code, ["result", "out"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1.0, 1.0, 1.0], + [2.0, 2.0, 2.0], + [3.0, 3.0, 3.0]]) + values = torch.tensor([[1.0, 1.0, 1.0], + [2.0, 2.0, 2.0], + [3.0, 3.0, 3.0]]).float() + indices = torch.tensor([[1, 1, 1], + [2, 2, 2], + [3, 3, 3]]) + out = (values, indices) + result = torch.cummin(x, dim = 0, out=(values, indices)) + """ + ) + obj.run(pytorch_code, ["result", "out"]) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1.0, 1.0, 1.0], + [2.0, 2.0, 2.0], + [3.0, 3.0, 3.0]]) + values = torch.tensor([[1.0, 1.0, 1.0], + [2.0, 2.0, 2.0], + [3.0, 3.0, 3.0]]).float() + indices = torch.tensor([[1, 1, 1], + [2, 2, 2], + [3, 3, 3]]) + out = (values, indices) + result = torch.cummin(input = x, dim =0, out=(values, indices)) + """ + ) + obj.run(pytorch_code, ["result", "out"]) diff --git a/tests/test_searchsorted.py b/tests/test_searchsorted.py new file mode 100644 index 000000000..a65405607 --- /dev/null +++ b/tests/test_searchsorted.py @@ -0,0 +1,137 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.searchsorted") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[ 1, 3, 5, 7, 9], + [ 2, 4, 6, 8, 10]]) + values = torch.tensor([[3, 6, 9], + [3, 6, 9]]) + result = torch.searchsorted(x, values) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[ 1, 3, 5, 7, 9], + [ 2, 4, 6, 8, 10]]) + values = torch.tensor([[3, 6, 9], + [3, 6, 9]]) + result = torch.searchsorted(x, values, out_int32 = True) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[ 1, 3, 5, 7, 9], + [ 2, 4, 6, 8, 10]]) + values = torch.tensor([[3, 6, 9], + [3, 6, 9]]) + result = torch.searchsorted(x, values, right = True) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[ 1, 3, 5, 7, 9], + [ 2, 4, 6, 8, 10]]) + values = torch.tensor([[3, 6, 9], + [3, 6, 9]]) + result = torch.searchsorted(x, values, side = 'right') + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[ 1, 3, 5, 7, 9], + [ 2, 4, 6, 8, 10]]) + values = torch.tensor([[3, 6, 9], + [3, 6, 9]]) + out = torch.tensor([[3, 6, 9], + [3, 6, 9]]) + result = torch.searchsorted(x, values, out = out) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[ 1, 3, 9, 7, 5], + [ 2, 4, 6, 8, 10]]) + values = torch.tensor([[3, 6, 9], + [3, 6, 9]]) + sorter = torch.argsort(x) + result = torch.searchsorted(x, values, sorter = sorter) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[ 1, 3, 9, 7, 5], + [ 2, 4, 6, 8, 10]]) + values = torch.tensor([[3, 6, 9], + [3, 6, 9]]) + out = torch.tensor([[3, 6, 9], + [3, 6, 9]]) + sorter = torch.argsort(x) + result = torch.searchsorted(x, values, right = True, side = 'right', out = out, sorter = sorter) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[ 1, 3, 5, 7, 9], + [ 2, 4, 6, 8, 10]]) + values = torch.tensor([[3, 6, 9], + [3, 6, 9]]) + result = torch.searchsorted(x, values, right = False, side = 'right') + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_vander.py b/tests/test_vander.py index aedf77b70..f7a52a094 100644 --- a/tests/test_vander.py +++ b/tests/test_vander.py @@ -27,12 +27,7 @@ def test_case_1(): result = torch.vander(x) """ ) - obj.run( - pytorch_code, - ["result"], - unsupport=True, - reason="paddle has no corresponding api tentatively", - ) + obj.run(pytorch_code, ["result"]) def test_case_2(): @@ -43,27 +38,17 @@ def test_case_2(): result = torch.vander(x, 3) """ ) - obj.run( - pytorch_code, - ["result"], - unsupport=True, - reason="paddle has no corresponding api tentatively", - ) + obj.run(pytorch_code, ["result"]) def test_case_3(): pytorch_code = textwrap.dedent( """ import torch - result = torch.vander(input=torch.tensor([1, 2, 3, 5]), N=3) + result = torch.vander(x=torch.tensor([1, 2, 3, 5]), N=3) """ ) - obj.run( - pytorch_code, - ["result"], - unsupport=True, - reason="paddle has no corresponding api tentatively", - ) + obj.run(pytorch_code, ["result"]) def test_case_4(): @@ -74,12 +59,7 @@ def test_case_4(): result = torch.vander(x, 5, increasing=True) """ ) - obj.run( - pytorch_code, - ["result"], - unsupport=True, - reason="paddle has no corresponding api tentatively", - ) + obj.run(pytorch_code, ["result"]) def test_case_5(): @@ -91,9 +71,15 @@ def test_case_5(): result = torch.vander(x, 5, increasing) """ ) - obj.run( - pytorch_code, - ["result"], - unsupport=True, - reason="paddle has no corresponding api tentatively", + obj.run(pytorch_code, ["result"]) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1, 2, 3, 5]) + result = torch.vander(x = x, N = 5, increasing=True) + """ ) + obj.run(pytorch_code, ["result"])