From 182bd89542c7641811e5f60debd501352f2e66ea Mon Sep 17 00:00:00 2001 From: LokeZhou Date: Wed, 5 Jul 2023 16:42:12 +0800 Subject: [PATCH] add some loss api (#109) * add some loss api * unify LossMatcher * fix mseloss * fix unfold bug * fix Tuple2ListMatcher * fix some comments --- paconvert/api_mapping.json | 77 +++++ paconvert/api_matcher.py | 21 ++ tests/test_nn_BCELoss.py | 146 ++++++++++ tests/test_nn_L1Loss.py | 139 +++++++++ tests/test_nn_MSELoss.py | 139 +++++++++ tests/test_nn_Unfold.py | 271 ++++++++++++++++++ ...test_nn_functional_binary_cross_entropy.py | 139 +++++++++ tests/test_nn_functional_unfold.py | 262 +++++++++++++++++ 8 files changed, 1194 insertions(+) create mode 100644 tests/test_nn_BCELoss.py create mode 100644 tests/test_nn_L1Loss.py create mode 100644 tests/test_nn_MSELoss.py create mode 100644 tests/test_nn_Unfold.py create mode 100644 tests/test_nn_functional_binary_cross_entropy.py create mode 100644 tests/test_nn_functional_unfold.py diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json index b57a58e05..786b9dee4 100644 --- a/paconvert/api_mapping.json +++ b/paconvert/api_mapping.json @@ -5504,6 +5504,16 @@ "divisor_override" ] }, + "torch.nn.BCELoss": { + "Matcher": "SizeAverageMatcher", + "paddle_api": "paddle.nn.BCELoss", + "args_list": [ + "weight", + "size_average", + "reduce", + "reduction" + ] + }, "torch.nn.BCEWithLogitsLoss": { "Matcher": "SizeAverageMatcher", "paddle_api": "paddle.nn.BCEWithLogitsLoss", @@ -6010,6 +6020,15 @@ "dtype": "" } }, + "torch.nn.L1Loss": { + "Matcher": "SizeAverageMatcher", + "paddle_api": "paddle.nn.L1Loss", + "args_list": [ + "size_average", + "reduce", + "reduction" + ] + }, "torch.nn.LSTM": { "Matcher": "RNNMatcher", "paddle_api": "paddle.nn.LSTM", @@ -6115,6 +6134,15 @@ "dim": "axis" } }, + "torch.nn.MSELoss": { + "Matcher": "SizeAverageMatcher", + "paddle_api": "paddle.nn.MSELoss", + "args_list": [ + "size_average", + "reduce", + "reduction" + ] + }, "torch.nn.MaxPool1d": { "Matcher": "MaxPoolMatcher", "paddle_api": "paddle.nn.MaxPool1D", @@ -6746,6 +6774,22 @@ "unflattened_size": "shape" } }, + "torch.nn.Unfold": { + "Matcher": "Tuple2ListMatcher", + "paddle_api": "paddle.nn.Unfold", + "args_list": [ + "kernel_size", + "dilation", + "padding", + "stride" + ], + "kwargs_change": { + "kernel_size": "kernel_sizes", + "dilation": "dilations", + "padding": "paddings", + "stride": "strides" + } + }, "torch.nn.Upsample": { "Matcher": "GenericMatcher", "paddle_api": "paddle.nn.Upsample", @@ -6951,6 +6995,21 @@ "input2": "x2" } }, + "torch.nn.functional.binary_cross_entropy": { + "Matcher": "SizeAverageMatcher", + "paddle_api": "paddle.nn.functional.binary_cross_entropy", + "args_list": [ + "input", + "target", + "weight", + "size_average", + "reduce", + "reduction" + ], + "kwargs_change": { + "target": "label" + } + }, "torch.nn.functional.binary_cross_entropy_with_logits": { "Matcher": "SizeAverageMatcher", "paddle_api": "paddle.nn.functional.binary_cross_entropy_with_logits", @@ -7852,6 +7911,24 @@ "reduction" ] }, + "torch.nn.functional.unfold": { + "Matcher": "Tuple2ListMatcher", + "paddle_api": "paddle.nn.functional.unfold", + "args_list": [ + "input", + "kernel_size", + "dilation", + "padding", + "stride" + ], + "kwargs_change": { + "input": "x", + "kernel_size": "kernel_sizes", + "dilation": "dilations", + "padding": "paddings", + "stride": "strides" + } + }, "torch.nn.functional.upsample": { "Matcher": "GenericMatcher", "paddle_api": "paddle.nn.functional.upsample", diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py index 9e5e0a3a7..81512b1b1 100644 --- a/paconvert/api_matcher.py +++ b/paconvert/api_matcher.py @@ -3210,6 +3210,27 @@ def generate_code(self, kwargs): return GenericMatcher.generate_code(self, kwargs) +class Tuple2ListMatcher(BaseMatcher): + def generate_code(self, kwargs): + new_kwargs = {} + kwargs_change = self.api_mapping["kwargs_change"] + for k in list(kwargs.keys()): + if k in kwargs_change: + if "," in kwargs[k]: + new_kwargs[kwargs_change[k]] = "list({})".format(kwargs[k]) + else: + new_kwargs[kwargs_change[k]] = kwargs[k] + else: + if "," in kwargs[k]: + new_kwargs[k] = "list({})".format(kwargs[k]) + else: + new_kwargs[k] = kwargs[k] + + code = "{}({})".format(self.get_paddle_api(), self.kwargs_to_str(new_kwargs)) + + return code + + class ParameterMatcher(BaseMatcher): def get_paddle_nodes(self, args, kwargs): kwargs = self.parse_args_and_kwargs(args, kwargs) diff --git a/tests/test_nn_BCELoss.py b/tests/test_nn_BCELoss.py new file mode 100644 index 000000000..15a75ccee --- /dev/null +++ b/tests/test_nn_BCELoss.py @@ -0,0 +1,146 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.nn.BCELoss") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[0.2837, 0.0297, 0.0355], + [ 0.9112, 0.7526, 0.4061]]) + target = torch.tensor([[1.,0.,1.],[0.,1.,0.]]) + weight = torch.tensor([0.5,0.2,0.3]) + loss = torch.nn.BCELoss(weight=weight,size_average=True) + result = loss(input,target) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[0.2837, 0.0297, 0.0355], + [ 0.9112, 0.7526, 0.4061]]) + target = torch.tensor([[1.,0.,1.],[0.,1.,0.]]) + weight = torch.tensor([0.5,0.2,0.3]) + loss = torch.nn.BCELoss(weight=weight,size_average=False) + result = loss(input,target) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[0.2837, 0.0297, 0.0355], + [ 0.9112, 0.7526, 0.4061]]) + target = torch.tensor([[1.,0.,1.],[0.,1.,0.]]) + weight = torch.tensor([0.5,0.2,0.3]) + loss = torch.nn.BCELoss(weight=weight,reduction='none') + result = loss(input,target) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[0.2837, 0.0297, 0.0355], + [ 0.9112, 0.7526, 0.4061]]) + target = torch.tensor([[1.,0.,1.],[0.,1.,0.]]) + weight = torch.tensor([0.5,0.2,0.3]) + loss = torch.nn.BCELoss(weight=weight,reduction='mean') + result = loss(input,target) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[0.2837, 0.0297, 0.0355], + [ 0.9112, 0.7526, 0.4061]]) + target = torch.tensor([[1.,0.,1.],[0.,1.,0.]]) + weight = torch.tensor([0.5,0.2,0.3]) + loss = torch.nn.BCELoss(weight=weight,reduction='sum') + result = loss(input,target) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[0.2837, 0.0297, 0.0355], + [ 0.9112, 0.7526, 0.4061]]) + target = torch.tensor([[1.,0.,1.],[0.,1.,0.]]) + weight = torch.tensor([0.5,0.2,0.3]) + loss = torch.nn.BCELoss(weight=weight,reduce=True) + result = loss(input,target) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[0.2837, 0.0297, 0.0355], + [ 0.9112, 0.7526, 0.4061]]) + target = torch.tensor([[1.,0.,1.],[0.,1.,0.]]) + weight = torch.tensor([0.5,0.2,0.3]) + loss = torch.nn.BCELoss(weight=weight,reduce=False) + result = loss(input,target) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[0.2837, 0.0297, 0.0355], + [ 0.9112, 0.7526, 0.4061]]) + target = torch.tensor([[1.,0.,1.],[0.,1.,0.]]) + loss = torch.nn.BCELoss() + result = loss(input,target) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_nn_L1Loss.py b/tests/test_nn_L1Loss.py new file mode 100644 index 000000000..5d821465c --- /dev/null +++ b/tests/test_nn_L1Loss.py @@ -0,0 +1,139 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.nn.L1Loss") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[-1.2837, -0.0297, 0.0355], + [ 0.9112, -1.7526, -0.4061]]) + target = torch.tensor([[1.,2.,3.],[4.,5.,6.]]) + loss = torch.nn.L1Loss(size_average=True) + result = loss(input,target) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[-1.2837, -0.0297, 0.0355], + [ 0.9112, -1.7526, -0.4061]]) + target = torch.tensor([[1.,2.,3.],[4.,5.,6.]]) + loss = torch.nn.L1Loss(size_average=False) + result = loss(input,target) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[-1.2837, -0.0297, 0.0355], + [ 0.9112, -1.7526, -0.4061]]) + target = torch.tensor([[1.,2.,3.],[4.,5.,6.]]) + loss = torch.nn.L1Loss(reduction='none') + result = loss(input,target) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[-1.2837, -0.0297, 0.0355], + [ 0.9112, -1.7526, -0.4061]]) + target = torch.tensor([[1.,2.,3.],[4.,5.,6.]]) + loss = torch.nn.L1Loss(reduction='mean') + result = loss(input,target) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[-1.2837, -0.0297, 0.0355], + [ 0.9112, -1.7526, -0.4061]]) + target = torch.tensor([[1.,2.,3.],[4.,5.,6.]]) + loss = torch.nn.L1Loss(reduction='sum') + result = loss(input,target) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[-1.2837, -0.0297, 0.0355], + [ 0.9112, -1.7526, -0.4061]]) + target = torch.tensor([[1.,2.,3.],[4.,5.,6.]]) + loss = torch.nn.L1Loss(reduce=True) + result = loss(input,target) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[-1.2837, -0.0297, 0.0355], + [ 0.9112, -1.7526, -0.4061]]) + target = torch.tensor([[1.,2.,3.],[4.,5.,6.]]) + loss = torch.nn.L1Loss(reduce=False) + result = loss(input,target) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[-1.2837, -0.0297, 0.0355], + [ 0.9112, -1.7526, -0.4061]]) + target = torch.tensor([[1.,2.,3.],[4.,5.,6.]]) + loss = torch.nn.L1Loss() + result = loss(input,target) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_nn_MSELoss.py b/tests/test_nn_MSELoss.py new file mode 100644 index 000000000..93f98f8d2 --- /dev/null +++ b/tests/test_nn_MSELoss.py @@ -0,0 +1,139 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.nn.MSELoss") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[-1.2837, -0.0297, 0.0355], + [ 0.9112, -1.7526, -0.4061]]) + target = torch.tensor([[1., 2., 1.],[1., 2., 3.]]) + loss = torch.nn.MSELoss(size_average=True) + result = loss(input, target) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[-1.2837, -0.0297, 0.0355], + [ 0.9112, -1.7526, -0.4061]]) + target = torch.tensor([[1., 2., 1.],[1., 2., 3.]]) + loss = torch.nn.MSELoss(size_average=False) + result = loss(input, target) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[-1.2837, -0.0297, 0.0355], + [ 0.9112, -1.7526, -0.4061]]) + target = torch.tensor([[1., 2., 1.],[1., 2., 3.]]) + loss = torch.nn.MSELoss(reduction='none') + result = loss(input, target) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[-1.2837, -0.0297, 0.0355], + [ 0.9112, -1.7526, -0.4061]]) + target = torch.tensor([[1., 2., 1.],[1., 2., 3.]]) + loss = torch.nn.MSELoss(reduction='mean') + result = loss(input, target) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[-1.2837, -0.0297, 0.0355], + [ 0.9112, -1.7526, -0.4061]]) + target = torch.tensor([[1., 2., 1.],[1., 2., 3.]]) + loss = torch.nn.MSELoss(reduction='sum') + result = loss(input, target) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[-1.2837, -0.0297, 0.0355], + [ 0.9112, -1.7526, -0.4061]]) + target = torch.tensor([[1., 2., 1.],[1., 2., 3.]]) + loss = torch.nn.MSELoss(reduce=True) + result = loss(input, target) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[-1.2837, -0.0297, 0.0355], + [ 0.9112, -1.7526, -0.4061]]) + target = torch.tensor([[1., 2., 1.],[1., 2., 3.]]) + loss = torch.nn.MSELoss(reduce=False) + result = loss(input, target) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[-1.2837, -0.0297, 0.0355], + [ 0.9112, -1.7526, -0.4061]]) + target = torch.tensor([[1., 2., 1.],[1., 2., 3.]]) + loss = torch.nn.MSELoss() + result = loss(input, target) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_nn_Unfold.py b/tests/test_nn_Unfold.py new file mode 100644 index 000000000..c737798bf --- /dev/null +++ b/tests/test_nn_Unfold.py @@ -0,0 +1,271 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.nn.Unfold") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor( + [[[[0.5018016, 0.71745074, 0.02612579, 0.04813039], + [0.14209914, 0.45702428, 0.06756079, 0.73914427], + [0.35131782, 0.03954667, 0.1214295, 0.25422984]], + + [[0.3040169, 0.650879, 0.29451096, 0.4443251 ], + [0.00550938, 0.38386834, 0.48462474, 0.49691153], + [0.9952472, 0.05594945, 0.6351355, 0.6343607 ]]], + + + [[[0.37795508, 0.63193935, 0.19294626, 0.77718097], + [0.785048, 0.67698157, 0.6636463, 0.63043 ], + [0.3141495, 0.48402798, 0.43465394, 0.52195907]], + + [[0.8227394, 0.47486508, 0.41936857, 0.08142513], + [0.518088, 0.5427299, 0.9754643, 0.58517313], + [0.0467307, 0.18104774, 0.9747845, 0.84306306]]]] + ) + + unfold = torch.nn.Unfold(kernel_size=(2,2)) + result = unfold(input) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor( + [[[[0.5018016, 0.71745074, 0.02612579, 0.04813039], + [0.14209914, 0.45702428, 0.06756079, 0.73914427], + [0.35131782, 0.03954667, 0.1214295, 0.25422984]], + + [[0.3040169, 0.650879, 0.29451096, 0.4443251 ], + [0.00550938, 0.38386834, 0.48462474, 0.49691153], + [0.9952472, 0.05594945, 0.6351355, 0.6343607 ]]], + + + [[[0.37795508, 0.63193935, 0.19294626, 0.77718097], + [0.785048, 0.67698157, 0.6636463, 0.63043 ], + [0.3141495, 0.48402798, 0.43465394, 0.52195907]], + + [[0.8227394, 0.47486508, 0.41936857, 0.08142513], + [0.518088, 0.5427299, 0.9754643, 0.58517313], + [0.0467307, 0.18104774, 0.9747845, 0.84306306]]]] + ) + + unfold = torch.nn.Unfold(kernel_size=(2, 3),padding=1) + result = unfold(input) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor( + [[[[0.5018016, 0.71745074, 0.02612579, 0.04813039], + [0.14209914, 0.45702428, 0.06756079, 0.73914427], + [0.35131782, 0.03954667, 0.1214295, 0.25422984]], + + [[0.3040169, 0.650879, 0.29451096, 0.4443251 ], + [0.00550938, 0.38386834, 0.48462474, 0.49691153], + [0.9952472, 0.05594945, 0.6351355, 0.6343607 ]]], + + + [[[0.37795508, 0.63193935, 0.19294626, 0.77718097], + [0.785048, 0.67698157, 0.6636463, 0.63043 ], + [0.3141495, 0.48402798, 0.43465394, 0.52195907]], + + [[0.8227394, 0.47486508, 0.41936857, 0.08142513], + [0.518088, 0.5427299, 0.9754643, 0.58517313], + [0.0467307, 0.18104774, 0.9747845, 0.84306306]]]] + ) + + unfold = torch.nn.Unfold(kernel_size=(2, 3),padding=(2,2)) + result = unfold(input) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor( + [[[[0.5018016, 0.71745074, 0.02612579, 0.04813039], + [0.14209914, 0.45702428, 0.06756079, 0.73914427], + [0.35131782, 0.03954667, 0.1214295, 0.25422984]], + + [[0.3040169, 0.650879, 0.29451096, 0.4443251 ], + [0.00550938, 0.38386834, 0.48462474, 0.49691153], + [0.9952472, 0.05594945, 0.6351355, 0.6343607 ]]], + + + [[[0.37795508, 0.63193935, 0.19294626, 0.77718097], + [0.785048, 0.67698157, 0.6636463, 0.63043 ], + [0.3141495, 0.48402798, 0.43465394, 0.52195907]], + + [[0.8227394, 0.47486508, 0.41936857, 0.08142513], + [0.518088, 0.5427299, 0.9754643, 0.58517313], + [0.0467307, 0.18104774, 0.9747845, 0.84306306]]]] + ) + + unfold = torch.nn.Unfold(kernel_size=(2, 3),dilation=(1,1)) + result = unfold(input) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor( + [[[[0.5018016, 0.71745074, 0.02612579, 0.04813039], + [0.14209914, 0.45702428, 0.06756079, 0.73914427], + [0.35131782, 0.03954667, 0.1214295, 0.25422984]], + + [[0.3040169, 0.650879, 0.29451096, 0.4443251 ], + [0.00550938, 0.38386834, 0.48462474, 0.49691153], + [0.9952472, 0.05594945, 0.6351355, 0.6343607 ]]], + + + [[[0.37795508, 0.63193935, 0.19294626, 0.77718097], + [0.785048, 0.67698157, 0.6636463, 0.63043 ], + [0.3141495, 0.48402798, 0.43465394, 0.52195907]], + + [[0.8227394, 0.47486508, 0.41936857, 0.08142513], + [0.518088, 0.5427299, 0.9754643, 0.58517313], + [0.0467307, 0.18104774, 0.9747845, 0.84306306]]]] + ) + + unfold = torch.nn.Unfold(kernel_size=(2, 3),stride=(2,2)) + result = unfold(input) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor( + [[[[0.5018016, 0.71745074, 0.02612579, 0.04813039], + [0.14209914, 0.45702428, 0.06756079, 0.73914427], + [0.35131782, 0.03954667, 0.1214295, 0.25422984]], + + [[0.3040169, 0.650879, 0.29451096, 0.4443251 ], + [0.00550938, 0.38386834, 0.48462474, 0.49691153], + [0.9952472, 0.05594945, 0.6351355, 0.6343607 ]]], + + + [[[0.37795508, 0.63193935, 0.19294626, 0.77718097], + [0.785048, 0.67698157, 0.6636463, 0.63043 ], + [0.3141495, 0.48402798, 0.43465394, 0.52195907]], + + [[0.8227394, 0.47486508, 0.41936857, 0.08142513], + [0.518088, 0.5427299, 0.9754643, 0.58517313], + [0.0467307, 0.18104774, 0.9747845, 0.84306306]]]] + ) + + unfold = torch.nn.Unfold(kernel_size=(2, 3),stride=(2,2),dilation=(1,1),padding=(2,2)) + result = unfold(input) + + """ + ) + obj.run(pytorch_code, ["result"]) + + +def _test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor( + [[[[0.5018016, 0.71745074, 0.02612579, 0.04813039], + [0.14209914, 0.45702428, 0.06756079, 0.73914427], + [0.35131782, 0.03954667, 0.1214295, 0.25422984]], + + [[0.3040169, 0.650879, 0.29451096, 0.4443251 ], + [0.00550938, 0.38386834, 0.48462474, 0.49691153], + [0.9952472, 0.05594945, 0.6351355, 0.6343607 ]]], + + + [[[0.37795508, 0.63193935, 0.19294626, 0.77718097], + [0.785048, 0.67698157, 0.6636463, 0.63043 ], + [0.3141495, 0.48402798, 0.43465394, 0.52195907]], + + [[0.8227394, 0.47486508, 0.41936857, 0.08142513], + [0.518088, 0.5427299, 0.9754643, 0.58517313], + [0.0467307, 0.18104774, 0.9747845, 0.84306306]]]] + ) + kernel_size=(2,2) + stride=(2,2) + dilation=(1,1) + padding=(1,1) + + unfold = torch.nn.Unfold(kernel_size=kernel_size,stride=stride,dilation=dilation,padding=padding) + result = unfold(input) + + """ + ) + obj.run( + pytorch_code, + unsupport=True, + reason="Unable to determine whether the variable is an tuple or a list", + ) + + +def test_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor( + [[[[0.5018016, 0.71745074, 0.02612579, 0.04813039], + [0.14209914, 0.45702428, 0.06756079, 0.73914427], + [0.35131782, 0.03954667, 0.1214295, 0.25422984]], + + [[0.3040169, 0.650879, 0.29451096, 0.4443251 ], + [0.00550938, 0.38386834, 0.48462474, 0.49691153], + [0.9952472, 0.05594945, 0.6351355, 0.6343607 ]]], + + + [[[0.37795508, 0.63193935, 0.19294626, 0.77718097], + [0.785048, 0.67698157, 0.6636463, 0.63043 ], + [0.3141495, 0.48402798, 0.43465394, 0.52195907]], + + [[0.8227394, 0.47486508, 0.41936857, 0.08142513], + [0.518088, 0.5427299, 0.9754643, 0.58517313], + [0.0467307, 0.18104774, 0.9747845, 0.84306306]]]] + ) + + + unfold = torch.nn.Unfold(kernel_size=[2,2],stride=[2,2],dilation=[1,1],padding=[1,1]) + result = unfold(input) + + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_nn_functional_binary_cross_entropy.py b/tests/test_nn_functional_binary_cross_entropy.py new file mode 100644 index 000000000..0a84ca525 --- /dev/null +++ b/tests/test_nn_functional_binary_cross_entropy.py @@ -0,0 +1,139 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.nn.functional.binary_cross_entropy") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[0.2837, 0.0297, 0.0355], + [ 0.9112, 0.7526, 0.4061]]) + target = torch.tensor([[1.,0.,1.],[0.,1.,0.]]) + weight = torch.tensor([0.5,0.2,0.3]) + result = torch.nn.functional.binary_cross_entropy(input,target,weight=weight,size_average=True) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[0.2837, 0.0297, 0.0355], + [ 0.9112, 0.7526, 0.4061]]) + target = torch.tensor([[1.,0.,1.],[0.,1.,0.]]) + weight = torch.tensor([0.5,0.2,0.3]) + weight = torch.tensor([0.5,0.2,0.3]) + result = torch.nn.functional.binary_cross_entropy(input,target,weight=weight,size_average=False) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[0.2837, 0.0297, 0.0355], + [ 0.9112, 0.7526, 0.4061]]) + target = torch.tensor([[1.,0.,1.],[0.,1.,0.]]) + weight = torch.tensor([0.5,0.2,0.3]) + result = torch.nn.functional.binary_cross_entropy(input,target,weight=weight,reduction='none') + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[0.2837, 0.0297, 0.0355], + [ 0.9112, 0.7526, 0.4061]]) + target = torch.tensor([[1.,0.,1.],[0.,1.,0.]]) + weight = torch.tensor([0.5,0.2,0.3]) + result = torch.nn.functional.binary_cross_entropy(input,target,weight=weight,reduction='mean') + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[0.2837, 0.0297, 0.0355], + [ 0.9112, 0.7526, 0.4061]]) + target = torch.tensor([[1.,0.,1.],[0.,1.,0.]]) + weight = torch.tensor([0.5,0.2,0.3]) + result = torch.nn.functional.binary_cross_entropy(input,target,weight=weight,reduction='sum') + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[0.2837, 0.0297, 0.0355], + [ 0.9112, 0.7526, 0.4061]]) + target = torch.tensor([[1.,0.,1.],[0.,1.,0.]]) + weight = torch.tensor([0.5,0.2,0.3]) + result = torch.nn.functional.binary_cross_entropy(input,target,weight=weight,reduce=True) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[0.2837, 0.0297, 0.0355], + [ 0.9112, 0.7526, 0.4061]]) + target = torch.tensor([[1.,0.,1.],[0.,1.,0.]]) + weight = torch.tensor([0.5,0.2,0.3]) + result = torch.nn.functional.binary_cross_entropy(input,target,weight=weight,reduce=False) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch + import torch.nn as nn + input = torch.tensor([[0.2837, 0.0297, 0.0355], + [ 0.9112, 0.7526, 0.4061]]) + target = torch.tensor([[1.,0.,1.],[0.,1.,0.]]) + result = torch.nn.functional.binary_cross_entropy(input,target) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_nn_functional_unfold.py b/tests/test_nn_functional_unfold.py new file mode 100644 index 000000000..6eb014dfe --- /dev/null +++ b/tests/test_nn_functional_unfold.py @@ -0,0 +1,262 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.nn.functional.unfold") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor( + [[[[0.5018016, 0.71745074, 0.02612579, 0.04813039], + [0.14209914, 0.45702428, 0.06756079, 0.73914427], + [0.35131782, 0.03954667, 0.1214295, 0.25422984]], + + [[0.3040169, 0.650879, 0.29451096, 0.4443251 ], + [0.00550938, 0.38386834, 0.48462474, 0.49691153], + [0.9952472, 0.05594945, 0.6351355, 0.6343607 ]]], + + + [[[0.37795508, 0.63193935, 0.19294626, 0.77718097], + [0.785048, 0.67698157, 0.6636463, 0.63043 ], + [0.3141495, 0.48402798, 0.43465394, 0.52195907]], + + [[0.8227394, 0.47486508, 0.41936857, 0.08142513], + [0.518088, 0.5427299, 0.9754643, 0.58517313], + [0.0467307, 0.18104774, 0.9747845, 0.84306306]]]] + ) + result = torch.nn.functional.unfold(input,kernel_size=(2, 3)) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor( + [[[[0.5018016, 0.71745074, 0.02612579, 0.04813039], + [0.14209914, 0.45702428, 0.06756079, 0.73914427], + [0.35131782, 0.03954667, 0.1214295, 0.25422984]], + + [[0.3040169, 0.650879, 0.29451096, 0.4443251 ], + [0.00550938, 0.38386834, 0.48462474, 0.49691153], + [0.9952472, 0.05594945, 0.6351355, 0.6343607 ]]], + + + [[[0.37795508, 0.63193935, 0.19294626, 0.77718097], + [0.785048, 0.67698157, 0.6636463, 0.63043 ], + [0.3141495, 0.48402798, 0.43465394, 0.52195907]], + + [[0.8227394, 0.47486508, 0.41936857, 0.08142513], + [0.518088, 0.5427299, 0.9754643, 0.58517313], + [0.0467307, 0.18104774, 0.9747845, 0.84306306]]]] + ) + + result = torch.nn.functional.unfold(input,kernel_size=(2, 3),padding=1) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor( + [[[[0.5018016, 0.71745074, 0.02612579, 0.04813039], + [0.14209914, 0.45702428, 0.06756079, 0.73914427], + [0.35131782, 0.03954667, 0.1214295, 0.25422984]], + + [[0.3040169, 0.650879, 0.29451096, 0.4443251 ], + [0.00550938, 0.38386834, 0.48462474, 0.49691153], + [0.9952472, 0.05594945, 0.6351355, 0.6343607 ]]], + + + [[[0.37795508, 0.63193935, 0.19294626, 0.77718097], + [0.785048, 0.67698157, 0.6636463, 0.63043 ], + [0.3141495, 0.48402798, 0.43465394, 0.52195907]], + + [[0.8227394, 0.47486508, 0.41936857, 0.08142513], + [0.518088, 0.5427299, 0.9754643, 0.58517313], + [0.0467307, 0.18104774, 0.9747845, 0.84306306]]]] + ) + + result = torch.nn.functional.unfold(input,kernel_size=(2, 3),padding=(2,2)) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor( + [[[[0.5018016, 0.71745074, 0.02612579, 0.04813039], + [0.14209914, 0.45702428, 0.06756079, 0.73914427], + [0.35131782, 0.03954667, 0.1214295, 0.25422984]], + + [[0.3040169, 0.650879, 0.29451096, 0.4443251 ], + [0.00550938, 0.38386834, 0.48462474, 0.49691153], + [0.9952472, 0.05594945, 0.6351355, 0.6343607 ]]], + + + [[[0.37795508, 0.63193935, 0.19294626, 0.77718097], + [0.785048, 0.67698157, 0.6636463, 0.63043 ], + [0.3141495, 0.48402798, 0.43465394, 0.52195907]], + + [[0.8227394, 0.47486508, 0.41936857, 0.08142513], + [0.518088, 0.5427299, 0.9754643, 0.58517313], + [0.0467307, 0.18104774, 0.9747845, 0.84306306]]]] + ) + + result = torch.nn.functional.unfold(input,kernel_size=(2, 3),dilation=(1,1)) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor( + [[[[0.5018016, 0.71745074, 0.02612579, 0.04813039], + [0.14209914, 0.45702428, 0.06756079, 0.73914427], + [0.35131782, 0.03954667, 0.1214295, 0.25422984]], + + [[0.3040169, 0.650879, 0.29451096, 0.4443251 ], + [0.00550938, 0.38386834, 0.48462474, 0.49691153], + [0.9952472, 0.05594945, 0.6351355, 0.6343607 ]]], + + + [[[0.37795508, 0.63193935, 0.19294626, 0.77718097], + [0.785048, 0.67698157, 0.6636463, 0.63043 ], + [0.3141495, 0.48402798, 0.43465394, 0.52195907]], + + [[0.8227394, 0.47486508, 0.41936857, 0.08142513], + [0.518088, 0.5427299, 0.9754643, 0.58517313], + [0.0467307, 0.18104774, 0.9747845, 0.84306306]]]] + ) + + result = torch.nn.functional.unfold(input,kernel_size=(2, 3),stride=(2,2)) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor( + [[[[0.5018016, 0.71745074, 0.02612579, 0.04813039], + [0.14209914, 0.45702428, 0.06756079, 0.73914427], + [0.35131782, 0.03954667, 0.1214295, 0.25422984]], + + [[0.3040169, 0.650879, 0.29451096, 0.4443251 ], + [0.00550938, 0.38386834, 0.48462474, 0.49691153], + [0.9952472, 0.05594945, 0.6351355, 0.6343607 ]]], + + + [[[0.37795508, 0.63193935, 0.19294626, 0.77718097], + [0.785048, 0.67698157, 0.6636463, 0.63043 ], + [0.3141495, 0.48402798, 0.43465394, 0.52195907]], + + [[0.8227394, 0.47486508, 0.41936857, 0.08142513], + [0.518088, 0.5427299, 0.9754643, 0.58517313], + [0.0467307, 0.18104774, 0.9747845, 0.84306306]]]] + ) + + result = torch.nn.functional.unfold(input,kernel_size=(2, 3),stride=(2,2),dilation=(1,1),padding=(2,2)) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def _test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor( + [[[[0.5018016, 0.71745074, 0.02612579, 0.04813039], + [0.14209914, 0.45702428, 0.06756079, 0.73914427], + [0.35131782, 0.03954667, 0.1214295, 0.25422984]], + + [[0.3040169, 0.650879, 0.29451096, 0.4443251 ], + [0.00550938, 0.38386834, 0.48462474, 0.49691153], + [0.9952472, 0.05594945, 0.6351355, 0.6343607 ]]], + + + [[[0.37795508, 0.63193935, 0.19294626, 0.77718097], + [0.785048, 0.67698157, 0.6636463, 0.63043 ], + [0.3141495, 0.48402798, 0.43465394, 0.52195907]], + + [[0.8227394, 0.47486508, 0.41936857, 0.08142513], + [0.518088, 0.5427299, 0.9754643, 0.58517313], + [0.0467307, 0.18104774, 0.9747845, 0.84306306]]]] + ) + kernel_size=(2,2) + stride=(2,2) + dilation=(1,1) + padding=(1,1) + + result = torch.nn.functional.unfold(input,kernel_size=kernel_size,stride=stride,dilation=dilation,padding=padding) + + """ + ) + obj.run( + pytorch_code, + unsupport=True, + reason="Unable to determine whether the variable is an tuple or a list", + ) + + +def test_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor( + [[[[0.5018016, 0.71745074, 0.02612579, 0.04813039], + [0.14209914, 0.45702428, 0.06756079, 0.73914427], + [0.35131782, 0.03954667, 0.1214295, 0.25422984]], + + [[0.3040169, 0.650879, 0.29451096, 0.4443251 ], + [0.00550938, 0.38386834, 0.48462474, 0.49691153], + [0.9952472, 0.05594945, 0.6351355, 0.6343607 ]]], + + + [[[0.37795508, 0.63193935, 0.19294626, 0.77718097], + [0.785048, 0.67698157, 0.6636463, 0.63043 ], + [0.3141495, 0.48402798, 0.43465394, 0.52195907]], + + [[0.8227394, 0.47486508, 0.41936857, 0.08142513], + [0.518088, 0.5427299, 0.9754643, 0.58517313], + [0.0467307, 0.18104774, 0.9747845, 0.84306306]]]] + ) + + + result = torch.nn.functional.unfold(input,kernel_size=[2,2],stride=[2,2],dilation=[1,1],padding=[1,1]) + + + """ + ) + obj.run(pytorch_code, ["result"])