diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json index 42f88efe6..02c198665 100644 --- a/paconvert/api_mapping.json +++ b/paconvert/api_mapping.json @@ -5811,6 +5811,23 @@ "A" ] }, + "torch.linalg.solve_triangular": { + "Matcher": "LinalgSolveTriangularMatcher", + "paddle_api": "paddle.linalg.triangular_solve", + "args_list": [ + "input", + "B", + "upper", + "left", + "unitriangular", + "out" + ], + "kwargs_change": { + "input": "x", + "B": "y", + "left": "transpose" + } + }, "torch.linalg.svdvals": { "Matcher": "LinalgSvdvalsMatcher", "paddle_api": "paddle.linalg.svd", diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py index 7fa52b0d5..a76770efb 100644 --- a/paconvert/api_matcher.py +++ b/paconvert/api_matcher.py @@ -3886,6 +3886,18 @@ def generate_code(self, kwargs): return GenericMatcher.generate_code(self, kwargs) +class LinalgSolveTriangularMatcher(BaseMatcher): + def generate_code(self, kwargs): + new_kwargs = {} + left = kwargs.pop("left").strip("(").strip(")") + if left == "True": + new_kwargs["transpose"] = "False" + if left == "False": + new_kwargs["transpose"] = "True" + new_kwargs.update(kwargs) + return GenericMatcher.generate_code(self, new_kwargs) + + class QrMatcher(BaseMatcher): def generate_code(self, kwargs): some_v = kwargs.pop("some") if "some" in kwargs else None diff --git a/tests/test_linalg_solve_triangular.py b/tests/test_linalg_solve_triangular.py new file mode 100644 index 000000000..91f5ed22a --- /dev/null +++ b/tests/test_linalg_solve_triangular.py @@ -0,0 +1,81 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.linalg.solve_triangular") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + A = torch.tensor([[ 1.1527, -1.0753], [ 1.23, 0.7986]]) + B = torch.tensor([[-0.0210, 2.3513, -1.5492], [ 1.5429, 0.7403, -1.0243]]) + result = torch.linalg.solve_triangular(A, B, upper=True) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + A = torch.tensor([[ 1.1527, -1.0753], [ 1.23, 0.7986]]) + B = torch.tensor([[-0.0210, 2.3513, -1.5492], [ 1.5429, 0.7403, -1.0243]]) + result = torch.linalg.solve_triangular(input=A, B=B, upper=True) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + A = torch.tensor([[ 1.1527, -1.0753], [ 1.23, 0.7986]]) + B = torch.tensor([[-0.0210, 2.3513, -1.5492], [ 1.5429, 0.7403, -1.0243]]) + result = torch.linalg.solve_triangular(input=A, B=B, unitriangular=True, upper=True) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + A = torch.tensor([[ 1.1527, -1.0753], [ 1.23, 0.7986]]) + B = torch.tensor([[-0.0210, 2.3513, -1.5492], [ 1.5429, 0.7403, -1.0243]]) + out = torch.tensor([]) + result = torch.linalg.solve_triangular(A, B, upper=True, left=True, unitriangular=False, out=out) + """ + ) + obj.run(pytorch_code, ["result", "out"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + A = torch.tensor([[ 1.1527, -1.0753], [ 1.23, 0.7986]]) + B = torch.tensor([[-0.0210, 2.3513, -1.5492], [ 1.5429, 0.7403, -1.0243]]) + out = torch.tensor([]) + result = torch.linalg.solve_triangular(input=A, B=B, upper=True, left=True, unitriangular=False, out=out) + """ + ) + obj.run(pytorch_code, ["result", "out"])