diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json index a8075c746..b57a58e05 100644 --- a/paconvert/api_mapping.json +++ b/paconvert/api_mapping.json @@ -3699,12 +3699,15 @@ } }, "torch.exp": { - "Matcher": "ExpMatcher", + "Matcher": "GenericMatcher", "paddle_api": "paddle.exp", "args_list": [ "input", "out" - ] + ], + "kwargs_change": { + "input": "x" + } }, "torch.exp2": { "Matcher": "Exp2Matcher", @@ -3714,7 +3717,7 @@ ] }, "torch.expm1": { - "Matcher": "ExpMatcher", + "Matcher": "GenericMatcher", "paddle_api": "paddle.expm1", "args_list": [ "input", @@ -4746,6 +4749,17 @@ "input": "x" } }, + "torch.linalg.multi_dot": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.linalg.multi_dot", + "args_list": [ + "tensors", + "out" + ], + "kwargs_change": { + "tensors": "x" + } + }, "torch.linalg.norm": { "Matcher": "GenericMatcher", "paddle_api": "paddle.linalg.norm", @@ -8660,6 +8674,17 @@ "input": "x" } }, + "torch.special.expm1": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.expm1", + "args_list": [ + "input", + "out" + ], + "kwargs_change": { + "input": "x" + } + }, "torch.special.log1p": { "Matcher": "GenericMatcher", "paddle_api": "paddle.log1p", diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py index 32e6134b3..9e5e0a3a7 100644 --- a/paconvert/api_matcher.py +++ b/paconvert/api_matcher.py @@ -2419,22 +2419,6 @@ def generate_code(self, kwargs): return code -class ExpMatcher(BaseMatcher): - def generate_code(self, kwargs): - if "input" in kwargs: - kwargs["x"] = "(" + kwargs.pop("input").strip("\n") + ").astype('float32')" - - if "out" in kwargs and kwargs["out"] is not None: - out_v = kwargs.pop("out").strip("\n") - code = "paddle.assign({}({}), output={})".format( - self.get_paddle_api(), self.kwargs_to_str(kwargs), out_v - ) - else: - code = "{}({})".format(self.get_paddle_api(), self.kwargs_to_str(kwargs)) - - return code - - class TensorSVDMatcher(BaseMatcher): def generate_code(self, kwargs): diff --git a/tests/test_dsplit.py b/tests/test_dsplit.py new file mode 100644 index 000000000..3a03f215d --- /dev/null +++ b/tests/test_dsplit.py @@ -0,0 +1,72 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.dsplit") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.arange(16.0).reshape(2, 2, 4) + result = torch.dsplit(a, 2) + """ + ) + obj.run( + pytorch_code, + ["result"], + unsupport=True, + reason="paddle does not support this function temporarily", + ) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.arange(16.0).reshape(2, 2, 4) + result = torch.dsplit(a, [2,2]) + if len(result) > 2: + result = (result[0], result[2]) + """ + ) + obj.run( + pytorch_code, + ["result"], + unsupport=True, + reason="paddle does not support this function temporarily", + ) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.arange(12).reshape(3, 2, 2) + result = torch.dsplit(a, indices=[1,1]) + if len(result) > 2: + result = (result[0], result[2]) + """ + ) + obj.run( + pytorch_code, + ["result"], + unsupport=True, + reason="paddle does not support this function temporarily", + ) diff --git a/tests/test_linalg_multi_dot.py b/tests/test_linalg_multi_dot.py new file mode 100644 index 000000000..9143c0d7c --- /dev/null +++ b/tests/test_linalg_multi_dot.py @@ -0,0 +1,54 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.linalg.multi_dot") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + input = [torch.tensor([1, 2], dtype=torch.float32), torch.tensor([2, 3], dtype=torch.float32)] + result = torch.linalg.multi_dot(input) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + input = [torch.tensor([1, 2], dtype=torch.float32), torch.tensor([2, 3], dtype=torch.float32)] + result = torch.linalg.multi_dot(tensors=input) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + input = [torch.tensor([1, 2], dtype=torch.float32), torch.tensor([2, 3], dtype=torch.float32)] + out = torch.tensor([]) + result = torch.linalg.multi_dot(input, out=out) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_special_expm1.py b/tests/test_special_expm1.py new file mode 100644 index 000000000..289546561 --- /dev/null +++ b/tests/test_special_expm1.py @@ -0,0 +1,64 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.special.expm1") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.special.expm1(torch.tensor([0., -2., 3.])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([-1., -2., 3.]) + result = torch.special.expm1(a) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + a = [-1, -2, 3] + out = torch.tensor(a, dtype=torch.float32) + result = torch.special.expm1(torch.tensor(a), out=out) + """ + ) + obj.run(pytorch_code, ["out"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.tensor([-1, -2, 3]) + result = torch.special.expm1(input=a) + """ + ) + obj.run(pytorch_code, ["result"])