From cbba4177ccc8d10f469743fb260ce549360041a2 Mon Sep 17 00:00:00 2001 From: co63oc Date: Mon, 28 Aug 2023 20:45:38 +0800 Subject: [PATCH] Add test --- paconvert/api_mapping.json | 17 ++++++ paconvert/api_matcher.py | 9 +++ tests/test_linalg_solve_triangular.py | 81 +++++++++++++++++++++++++++ 3 files changed, 107 insertions(+) create mode 100644 tests/test_linalg_solve_triangular.py diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json index dd045688a..82b20f7dd 100644 --- a/paconvert/api_mapping.json +++ b/paconvert/api_mapping.json @@ -5862,6 +5862,23 @@ "A" ] }, + "torch.linalg.solve_triangular": { + "Matcher": "LinalgSolveTriangularMatcher", + "paddle_api": "paddle.linalg.triangular_solve", + "args_list": [ + "input", + "B", + "upper", + "left", + "unitriangular", + "out" + ], + "kwargs_change": { + "input": "x", + "B": "y", + "left": "transpose" + } + }, "torch.linalg.svdvals": { "Matcher": "LinalgSvdvalsMatcher", "paddle_api": "paddle.linalg.svd", diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py index 4951d4000..189b491a7 100644 --- a/paconvert/api_matcher.py +++ b/paconvert/api_matcher.py @@ -4061,6 +4061,15 @@ def generate_code(self, kwargs): return code +class LinalgSolveTriangularMatcher(BaseMatcher): + def generate_code(self, kwargs): + new_kwargs = {} + if "left" in kwargs: + new_kwargs["transpose"] = f"not {kwargs.pop('left')}" + new_kwargs.update(kwargs) + return GenericMatcher.generate_code(self, new_kwargs) + + class QrMatcher(BaseMatcher): def generate_code(self, kwargs): some_v = kwargs.pop("some") if "some" in kwargs else None diff --git a/tests/test_linalg_solve_triangular.py b/tests/test_linalg_solve_triangular.py new file mode 100644 index 000000000..91f5ed22a --- /dev/null +++ b/tests/test_linalg_solve_triangular.py @@ -0,0 +1,81 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.linalg.solve_triangular") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + A = torch.tensor([[ 1.1527, -1.0753], [ 1.23, 0.7986]]) + B = torch.tensor([[-0.0210, 2.3513, -1.5492], [ 1.5429, 0.7403, -1.0243]]) + result = torch.linalg.solve_triangular(A, B, upper=True) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + A = torch.tensor([[ 1.1527, -1.0753], [ 1.23, 0.7986]]) + B = torch.tensor([[-0.0210, 2.3513, -1.5492], [ 1.5429, 0.7403, -1.0243]]) + result = torch.linalg.solve_triangular(input=A, B=B, upper=True) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + A = torch.tensor([[ 1.1527, -1.0753], [ 1.23, 0.7986]]) + B = torch.tensor([[-0.0210, 2.3513, -1.5492], [ 1.5429, 0.7403, -1.0243]]) + result = torch.linalg.solve_triangular(input=A, B=B, unitriangular=True, upper=True) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + A = torch.tensor([[ 1.1527, -1.0753], [ 1.23, 0.7986]]) + B = torch.tensor([[-0.0210, 2.3513, -1.5492], [ 1.5429, 0.7403, -1.0243]]) + out = torch.tensor([]) + result = torch.linalg.solve_triangular(A, B, upper=True, left=True, unitriangular=False, out=out) + """ + ) + obj.run(pytorch_code, ["result", "out"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + A = torch.tensor([[ 1.1527, -1.0753], [ 1.23, 0.7986]]) + B = torch.tensor([[-0.0210, 2.3513, -1.5492], [ 1.5429, 0.7403, -1.0243]]) + out = torch.tensor([]) + result = torch.linalg.solve_triangular(input=A, B=B, upper=True, left=True, unitriangular=False, out=out) + """ + ) + obj.run(pytorch_code, ["result", "out"])