From 8629c0455f88341c0f966a2b2930745892f23b6d Mon Sep 17 00:00:00 2001
From: co63oc
Date: Wed, 5 Jul 2023 20:55:56 +0800
Subject: [PATCH] Add tests

---
 paconvert/api_mapping.json                 | 111 +++++++++++++++++++++
 paconvert/api_matcher.py                   |  41 +++++++-
 tests/test_autograd_backward.py            |  77 ++++++++++++++
 tests/test_autograd_functional_hessian.py  |  70 +++++++++++++
 tests/test_autograd_functional_jacobian.py |  69 +++++++++++++
 tests/test_autograd_functional_jvp.py      |  73 ++++++++++++++
 tests/test_autograd_functional_vjp.py      |  74 ++++++++++++++
 tests/test_cuda_device.py                  |  82 +++++++++++++++
 8 files changed, 595 insertions(+), 2 deletions(-)
 create mode 100644 tests/test_autograd_backward.py
 create mode 100644 tests/test_autograd_functional_hessian.py
 create mode 100644 tests/test_autograd_functional_jacobian.py
 create mode 100644 tests/test_autograd_functional_jvp.py
 create mode 100644 tests/test_autograd_functional_vjp.py
 create mode 100644 tests/test_cuda_device.py

diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json
index 3f169ba4c..2075c52e3 100644
--- a/paconvert/api_mapping.json
+++ b/paconvert/api_mapping.json
@@ -2834,6 +2834,107 @@
             "input": "x"
         }
     },
+    "torch.autograd.backward": {
+        "Matcher": "GenericMatcher",
+        "paddle_api": "paddle.autograd.backward",
+        "args_list": [
+            "tensors",
+            "grad_tensors",
+            "retain_graph",
+            "create_graph",
+            "grad_variables",
+            "inputs"
+        ],
+        "unsupport_args": [
+            "create_graph",
+            "grad_variables",
+            "inputs"
+        ]
+    },
+    "torch.autograd.functional.hessian": {
+        "Matcher": "GenericMatcher",
+        "paddle_api": "paddle.incubate.autograd.Hessian",
+        "args_list": [
+            "func",
+            "inputs",
+            "create_graph",
+            "strict",
+            "vectorize",
+            "outer_jacobian_strategy"
+        ],
+        "unsupport_args": [
+            "create_graph",
+            "strict",
+            "vectorize",
+            "outer_jacobian_strategy"
+        ],
+        "kwargs_change": {
+            "inputs": "xs"
+        },
+        "paddle_default_kwargs": {
+            "is_batched": false
+        }
+    },
+    "torch.autograd.functional.jacobian": {
+        "Matcher": "GenericMatcher",
+        "paddle_api": "paddle.incubate.autograd.Jacobian",
+        "args_list": [
+            "func",
+            "inputs",
+            "create_graph",
+            "strict",
+            "vectorize",
+            "strategy"
+        ],
+        "unsupport_args": [
+            "create_graph",
+            "strict",
+            "vectorize",
+            "strategy"
+        ],
+        "kwargs_change": {
+            "inputs": "xs"
+        },
+        "paddle_default_kwargs": {
+            "is_batched": false
+        }
+    },
+    "torch.autograd.functional.jvp": {
+        "Matcher": "GenericMatcher",
+        "paddle_api": "paddle.incubate.autograd.jvp",
+        "args_list": [
+            "func",
+            "inputs",
+            "v",
+            "create_graph",
+            "strict"
+        ],
+        "unsupport_args": [
+            "create_graph",
+            "strict"
+        ],
+        "kwargs_change": {
+            "inputs": "xs"
+        }
+    },
+    "torch.autograd.functional.vjp": {
+        "Matcher": "GenericMatcher",
+        "paddle_api": "paddle.incubate.autograd.vjp",
+        "args_list": [
+            "func",
+            "inputs",
+            "v",
+            "create_graph",
+            "strict"
+        ],
+        "unsupport_args": [
+            "create_graph",
+            "strict"
+        ],
+        "kwargs_change": {
+            "inputs": "xs"
+        }
+    },
     "torch.autograd.grad": {
         "Matcher": "GenericMatcher",
         "paddle_api": "paddle.grad",
@@ -3309,6 +3410,16 @@
             "device": "device"
         }
     },
+    "torch.cuda.device": {
+        "Matcher": "CudaDeviceMatcher",
+        "paddle_api": "paddle.CUDAPlace",
+        "args_list": [
+            "device"
+        ],
+        "kwargs_change": {
+            "device": "id"
+        }
+    },
     "torch.cuda.device_count": {
         "Matcher": "GenericMatcher",
         "paddle_api": "paddle.device.cuda.device_count"
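A note on how GenericMatcher consumes the autograd entries above: `args_list` names torch's positional parameters, `kwargs_change` renames `inputs` to Paddle's `xs`, and `paddle_default_kwargs` appends `is_batched=false`. Roughly, for the jacobian mapping (my own before/after sketch, not part of the patch; the exact keyword formatting GenericMatcher emits may differ):

    # before (PyTorch)
    J = torch.autograd.functional.jacobian(func, x)
    # after (Paddle): inputs renamed to xs, default is_batched appended
    J = paddle.incubate.autograd.Jacobian(func, xs=x, is_batched=False)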
diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py
index f6148c70d..20f60c994 100644
--- a/paconvert/api_matcher.py
+++ b/paconvert/api_matcher.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import ast
+import re
 import textwrap
 
 import astor
@@ -28,7 +29,7 @@ def get_paddle_api(self):
             return self.paddle_api
         return self.api_mapping["paddle_api"]
 
-    def generate_code(self, kwargs):
+    def generate_code(self, kwargs, args=[]):
         kwargs_change = {}
         if "kwargs_change" in self.api_mapping:
             kwargs_change = self.api_mapping["kwargs_change"]
@@ -80,7 +81,9 @@ def generate_code(self, kwargs):
         if "out" in new_kwargs:
             out_v = new_kwargs.pop("out")
 
-        res = "{}({})".format(self.get_paddle_api(), self.kwargs_to_str(new_kwargs))
+        res = "{}({})".format(
+            self.get_paddle_api(), self.args_and_kwargs_to_str(args, new_kwargs)
+        )
 
         if dtype_v:
             res += ".astype({})".format(dtype_v)
@@ -3777,3 +3780,37 @@ def generate_code(self, kwargs):
         code = "{}".format(self.get_paddle_api())
 
         return code
+
+
+class CudaDeviceMatcher(BaseMatcher):
+    def generate_code(self, kwargs):
+        if not kwargs["device"].strip("()").isdigit():
+            device = kwargs["device"]
+            if (
+                "replace('cuda', 'gpu')," in device
+                or 'replace("cuda", "gpu"),' in device
+            ):
+                m = re.search(r"\(([0-9]+)\)", device)
+                if m:
+                    kwargs["device"] = m.group(1)
+                else:
+                    return None
+            elif (
+                "replace('cuda', 'gpu')" in device or 'replace("cuda", "gpu")' in device
+            ):
+                kwargs["device"] = 0
+            else:
+                return None
+
+        kwargs_change = {}
+        if "kwargs_change" in self.api_mapping:
+            kwargs_change = self.api_mapping["kwargs_change"]
+
+        args = []
+        new_kwargs = {}
+        for ele in kwargs:
+            if ele in kwargs_change and kwargs_change[ele] == "id":
+                args.append(kwargs[ele])
+            else:
+                new_kwargs[ele] = kwargs[ele]
+        return GenericMatcher.generate_code(self, new_kwargs, args)
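For readers of CudaDeviceMatcher above: the `device` kwarg arrives as already-converted source text. A bare (possibly parenthesized) integer passes through unchanged; a `torch.device("cuda", i)` expression, which the earlier `torch.device` rewrite leaves containing a `.replace('cuda', 'gpu')` call, has its index recovered by the `\(([0-9]+)\)` regex (the trailing-comma variant signals that an explicit index was given); an index-less form defaults to device 0; anything else returns None, marking the call unconvertible. The `"device": "id"` entry in `kwargs_change` then routes the value into the positional `args` list, presumably because the C++-bound `paddle.CUDAPlace` constructor does not accept keyword arguments. The intended rewrites look roughly like this (illustrative sketch, not part of the patch):

    torch.cuda.device(0)                        # -> paddle.CUDAPlace(0)
    torch.cuda.device(device=0)                 # -> paddle.CUDAPlace(0)
    torch.cuda.device(torch.device("cuda", 0))  # regex recovers 0 -> paddle.CUDAPlace(0)
    torch.cuda.device(torch.device("cuda"))     # no index, defaults -> paddle.CUDAPlace(0)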
diff --git a/tests/test_autograd_backward.py b/tests/test_autograd_backward.py
new file mode 100644
index 000000000..35972de59
--- /dev/null
+++ b/tests/test_autograd_backward.py
@@ -0,0 +1,77 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import textwrap
+
+from apibase import APIBase
+
+obj = APIBase("torch.autograd.backward")
+
+
+def test_case_1():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        x = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32, requires_grad=True)
+        y = torch.tensor([[3, 2], [3, 4]], dtype=torch.float32)
+
+        grad_tensor1 = torch.tensor([[1, 2], [2, 3]], dtype=torch.float32)
+        grad_tensor2 = torch.tensor([[1, 1], [1, 1]], dtype=torch.float32)
+
+        z1 = torch.matmul(x, y)
+        z2 = torch.matmul(x, y)
+
+        torch.autograd.backward([z1, z2], [grad_tensor1, grad_tensor2], True)
+        x.grad.requires_grad = False
+        result = x.grad
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_2():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        x = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32, requires_grad=True)
+        y = torch.tensor([[3, 2], [3, 4]], dtype=torch.float32)
+
+        grad_tensor1 = torch.tensor([[1, 2], [2, 3]], dtype=torch.float32)
+        grad_tensor2 = torch.tensor([[1, 1], [1, 1]], dtype=torch.float32)
+
+        z1 = torch.matmul(x, y)
+        z2 = torch.matmul(x, y)
+
+        torch.autograd.backward([z1, z2], [grad_tensor1, grad_tensor2], retain_graph=False)
+        x.grad.requires_grad = False
+        result = x.grad
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_3():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        x = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32, requires_grad=True)
+        z1 = x.sum()
+
+        torch.autograd.backward([z1])
+        x.grad.requires_grad = False
+        result = x.grad
+        """
+    )
+    obj.run(pytorch_code, ["result"])
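Since the `torch.autograd.backward` mapping marks `create_graph`, `grad_variables`, and `inputs` as unsupported, only the `tensors`, `grad_tensors`, and `retain_graph` arguments exercised above are translated. test_case_1 should convert to something like the following (my sketch, assuming the Paddle signature paddle.autograd.backward(tensors, grad_tensors=None, retain_graph=False); the exact positional/keyword form GenericMatcher emits may differ):

    # roughly:
    paddle.autograd.backward([z1, z2], [grad_tensor1, grad_tensor2], True)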
diff --git a/tests/test_autograd_functional_hessian.py b/tests/test_autograd_functional_hessian.py
new file mode 100644
index 000000000..dd481f2b8
--- /dev/null
+++ b/tests/test_autograd_functional_hessian.py
@@ -0,0 +1,70 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import textwrap
+
+from apibase import APIBase
+
+obj = APIBase("torch.autograd.functional.hessian")
+
+
+def test_case_1():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        def func(x):
+            return torch.sum(x * x)
+        x = torch.rand(2, 2)
+        h = torch.autograd.functional.hessian(func, x)
+        result = h[:]
+        result.requires_grad = False
+        result = torch.flatten(result)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_2():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+
+        def func(x):
+            return 2 * torch.sum(x * x + 3 * x)
+
+        x = torch.rand(2, 2)
+        h = torch.autograd.functional.hessian(func, x)
+        result = h[:]
+        result.requires_grad = False
+        result = torch.flatten(result)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_3():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+
+        def func(x):
+            return torch.sum(x)
+
+        x = torch.tensor([1.0, 2.0])
+        h = torch.autograd.functional.hessian(func, x)
+        result = h[:]
+        result.requires_grad = False
+        result = torch.flatten(result)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
diff --git a/tests/test_autograd_functional_jacobian.py b/tests/test_autograd_functional_jacobian.py
new file mode 100644
index 000000000..9bac779ac
--- /dev/null
+++ b/tests/test_autograd_functional_jacobian.py
@@ -0,0 +1,69 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import textwrap
+
+from apibase import APIBase
+
+obj = APIBase("torch.autograd.functional.jacobian")
+
+
+def test_case_1():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+
+        def func(x):
+            return x * x
+
+        x = torch.tensor([1., 2.])
+        J = torch.autograd.functional.jacobian(func, x)
+        result = J[:]
+        result.requires_grad = False
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_2():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+
+        def func(x):
+            return torch.cos(x)
+
+        x = torch.tensor([1., 2.])
+        J = torch.autograd.functional.jacobian(func, x)
+        result = J[:]
+        result.requires_grad = False
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_3():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+
+        def func(x):
+            return torch.log(x)
+
+        x = torch.tensor([1., 2.])
+        J = torch.autograd.functional.jacobian(func, x)
+        result = J[:]
+        result.requires_grad = False
+        """
+    )
+    obj.run(pytorch_code, ["result"])
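One behavioral difference is worth noting for the two test files above: `torch.autograd.functional.hessian`/`jacobian` return plain tensors, while `paddle.incubate.autograd.Hessian`/`Jacobian` return lazily evaluated objects. Slicing with `[:]` is a harmless view on the torch side and forces full evaluation on the Paddle side, which is why every case reads the result through `h[:]`/`J[:]`. Illustrative sketch (not part of the patch):

    J = paddle.incubate.autograd.Jacobian(func, xs=x, is_batched=False)
    mat = J[:]  # materializes the full Jacobian matrix as a Tensor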
diff --git a/tests/test_autograd_functional_jvp.py b/tests/test_autograd_functional_jvp.py
new file mode 100644
index 000000000..3050ae22d
--- /dev/null
+++ b/tests/test_autograd_functional_jvp.py
@@ -0,0 +1,73 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import textwrap
+
+from apibase import APIBase
+
+obj = APIBase("torch.autograd.functional.jvp")
+
+
+def test_case_1():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        def func(x):
+            return torch.sum(x * x)
+        x = torch.ones(2, 3)
+        v = torch.ones(2, 3)
+        h = torch.autograd.functional.jvp(func, x, v)
+        result = h[:]
+        for item in result:
+            item.requires_grad = False
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_2():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+
+        def func(x):
+            return 2 * torch.sum(x * x + 3 * x)
+
+        x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
+        v = torch.ones(2, 3)
+        h = torch.autograd.functional.jvp(func, x, v)
+        result = h[:]
+        for item in result:
+            item.requires_grad = False
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_3():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+
+        def func(x):
+            return torch.sum(x)
+
+        x = torch.tensor([1.0, 2.0])
+        v = torch.ones(2)
+        h = torch.autograd.functional.jvp(func, x, v)
+        result = h[:]
+        for item in result:
+            item.requires_grad = False
+        """
+    )
+    obj.run(pytorch_code, ["result"])
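Unlike Jacobian/Hessian, both `torch.autograd.functional.jvp` and `paddle.incubate.autograd.jvp` return a 2-tuple of (function output, product), which is why the jvp and vjp tests iterate over `result` to clear `requires_grad` on each element. Sketch (illustrative, not part of the patch):

    out, jv = paddle.incubate.autograd.jvp(func, xs=x, v=v)  # (func(x), J @ v)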
diff --git a/tests/test_autograd_functional_vjp.py b/tests/test_autograd_functional_vjp.py
new file mode 100644
index 000000000..8139be63d
--- /dev/null
+++ b/tests/test_autograd_functional_vjp.py
@@ -0,0 +1,74 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import textwrap
+
+from apibase import APIBase
+
+obj = APIBase("torch.autograd.functional.vjp")
+
+
+def test_case_1():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        def func(x):
+            return x.sum(dim=1)
+
+        x = torch.ones(2, 2)
+        v = torch.ones(2)
+        h = torch.autograd.functional.vjp(func, x, v)
+        result = h[:]
+        for item in result:
+            item.requires_grad = False
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_2():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+
+        def func(x):
+            return x.sum(dim=1)
+
+        x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
+        v = torch.ones(2)
+        h = torch.autograd.functional.vjp(func, x, v)
+        result = h[:]
+        for item in result:
+            item.requires_grad = False
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_3():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+
+        def func(x):
+            return x * x
+
+        x = torch.tensor([1.0, 2.0])
+        v = torch.ones(2)
+        h = torch.autograd.functional.vjp(func, x, v)
+        result = h[:]
+        for item in result:
+            item.requires_grad = False
+        """
+    )
+    obj.run(pytorch_code, ["result"])
diff --git a/tests/test_cuda_device.py b/tests/test_cuda_device.py
new file mode 100644
index 000000000..f7695585e
--- /dev/null
+++ b/tests/test_cuda_device.py
@@ -0,0 +1,82 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import textwrap
+
+import paddle
+from apibase import APIBase
+
+
+class CudaDeviceAPIBase(APIBase):
+    def compare(self, name, pytorch_result, paddle_result, check_value=True):
+        if (
+            isinstance(paddle_result, paddle.fluid.libpaddle.Place)
+            or paddle_result is None
+        ):
+            return True
+        return False
+
+
+obj = CudaDeviceAPIBase("torch.cuda.device")
+
+
+def test_case_1():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        if torch.cuda.is_available():
+            result = torch.cuda.device(device=0)
+        else:
+            result = None
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_2():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        if torch.cuda.is_available():
+            result = torch.cuda.device(0)
+        else:
+            result = None
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_3():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        if torch.cuda.is_available():
+            result = torch.cuda.device(torch.device("cuda"))
+        else:
+            result = None
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_4():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        if torch.cuda.is_available():
+            result = torch.cuda.device(torch.device("cuda", 0))
+        else:
+            result = None
+        """
+    )
+    obj.run(pytorch_code, ["result"])
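A caveat the custom `compare` above deliberately sidesteps: `torch.cuda.device` is a context manager that temporarily switches the active device, while `paddle.CUDAPlace` is a plain place object, so the test only checks that the converted result is a `Place` (or None on CPU-only machines) rather than comparing values. A `with torch.cuda.device(0): ...` block is therefore outside what these cases exercise (my reading of the tests, not a claim from the patch):

    with torch.cuda.device(0):   # torch: temporarily selects GPU 0
        ...
    place = paddle.CUDAPlace(0)  # paddle: just a device descriptor object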