Skip to content

Commit

Permalink
add some loss api (#109)
Browse files Browse the repository at this point in the history
* add some loss api

* unify LossMatcher

* fix mseloss

* fix unfold bug

* fix Tuple2ListMatcher

* fix some comments
  • Loading branch information
LokeZhou authored Jul 5, 2023
1 parent b11376a commit 182bd89
Show file tree
Hide file tree
Showing 8 changed files with 1,194 additions and 0 deletions.
77 changes: 77 additions & 0 deletions paconvert/api_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -5504,6 +5504,16 @@
"divisor_override"
]
},
"torch.nn.BCELoss": {
"Matcher": "SizeAverageMatcher",
"paddle_api": "paddle.nn.BCELoss",
"args_list": [
"weight",
"size_average",
"reduce",
"reduction"
]
},
"torch.nn.BCEWithLogitsLoss": {
"Matcher": "SizeAverageMatcher",
"paddle_api": "paddle.nn.BCEWithLogitsLoss",
Expand Down Expand Up @@ -6010,6 +6020,15 @@
"dtype": ""
}
},
"torch.nn.L1Loss": {
"Matcher": "SizeAverageMatcher",
"paddle_api": "paddle.nn.L1Loss",
"args_list": [
"size_average",
"reduce",
"reduction"
]
},
"torch.nn.LSTM": {
"Matcher": "RNNMatcher",
"paddle_api": "paddle.nn.LSTM",
Expand Down Expand Up @@ -6115,6 +6134,15 @@
"dim": "axis"
}
},
"torch.nn.MSELoss": {
"Matcher": "SizeAverageMatcher",
"paddle_api": "paddle.nn.MSELoss",
"args_list": [
"size_average",
"reduce",
"reduction"
]
},
"torch.nn.MaxPool1d": {
"Matcher": "MaxPoolMatcher",
"paddle_api": "paddle.nn.MaxPool1D",
Expand Down Expand Up @@ -6746,6 +6774,22 @@
"unflattened_size": "shape"
}
},
"torch.nn.Unfold": {
"Matcher": "Tuple2ListMatcher",
"paddle_api": "paddle.nn.Unfold",
"args_list": [
"kernel_size",
"dilation",
"padding",
"stride"
],
"kwargs_change": {
"kernel_size": "kernel_sizes",
"dilation": "dilations",
"padding": "paddings",
"stride": "strides"
}
},
"torch.nn.Upsample": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.Upsample",
Expand Down Expand Up @@ -6951,6 +6995,21 @@
"input2": "x2"
}
},
"torch.nn.functional.binary_cross_entropy": {
"Matcher": "SizeAverageMatcher",
"paddle_api": "paddle.nn.functional.binary_cross_entropy",
"args_list": [
"input",
"target",
"weight",
"size_average",
"reduce",
"reduction"
],
"kwargs_change": {
"target": "label"
}
},
"torch.nn.functional.binary_cross_entropy_with_logits": {
"Matcher": "SizeAverageMatcher",
"paddle_api": "paddle.nn.functional.binary_cross_entropy_with_logits",
Expand Down Expand Up @@ -7852,6 +7911,24 @@
"reduction"
]
},
"torch.nn.functional.unfold": {
"Matcher": "Tuple2ListMatcher",
"paddle_api": "paddle.nn.functional.unfold",
"args_list": [
"input",
"kernel_size",
"dilation",
"padding",
"stride"
],
"kwargs_change": {
"input": "x",
"kernel_size": "kernel_sizes",
"dilation": "dilations",
"padding": "paddings",
"stride": "strides"
}
},
"torch.nn.functional.upsample": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.functional.upsample",
Expand Down
21 changes: 21 additions & 0 deletions paconvert/api_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3210,6 +3210,27 @@ def generate_code(self, kwargs):
return GenericMatcher.generate_code(self, kwargs)


class Tuple2ListMatcher(BaseMatcher):
def generate_code(self, kwargs):
new_kwargs = {}
kwargs_change = self.api_mapping["kwargs_change"]
for k in list(kwargs.keys()):
if k in kwargs_change:
if "," in kwargs[k]:
new_kwargs[kwargs_change[k]] = "list({})".format(kwargs[k])
else:
new_kwargs[kwargs_change[k]] = kwargs[k]
else:
if "," in kwargs[k]:
new_kwargs[k] = "list({})".format(kwargs[k])
else:
new_kwargs[k] = kwargs[k]

code = "{}({})".format(self.get_paddle_api(), self.kwargs_to_str(new_kwargs))

return code


class ParameterMatcher(BaseMatcher):
def get_paddle_nodes(self, args, kwargs):
kwargs = self.parse_args_and_kwargs(args, kwargs)
Expand Down
146 changes: 146 additions & 0 deletions tests/test_nn_BCELoss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.nn.BCELoss")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
input = torch.tensor([[0.2837, 0.0297, 0.0355],
[ 0.9112, 0.7526, 0.4061]])
target = torch.tensor([[1.,0.,1.],[0.,1.,0.]])
weight = torch.tensor([0.5,0.2,0.3])
loss = torch.nn.BCELoss(weight=weight,size_average=True)
result = loss(input,target)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
input = torch.tensor([[0.2837, 0.0297, 0.0355],
[ 0.9112, 0.7526, 0.4061]])
target = torch.tensor([[1.,0.,1.],[0.,1.,0.]])
weight = torch.tensor([0.5,0.2,0.3])
loss = torch.nn.BCELoss(weight=weight,size_average=False)
result = loss(input,target)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
input = torch.tensor([[0.2837, 0.0297, 0.0355],
[ 0.9112, 0.7526, 0.4061]])
target = torch.tensor([[1.,0.,1.],[0.,1.,0.]])
weight = torch.tensor([0.5,0.2,0.3])
loss = torch.nn.BCELoss(weight=weight,reduction='none')
result = loss(input,target)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_4():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
input = torch.tensor([[0.2837, 0.0297, 0.0355],
[ 0.9112, 0.7526, 0.4061]])
target = torch.tensor([[1.,0.,1.],[0.,1.,0.]])
weight = torch.tensor([0.5,0.2,0.3])
loss = torch.nn.BCELoss(weight=weight,reduction='mean')
result = loss(input,target)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_5():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
input = torch.tensor([[0.2837, 0.0297, 0.0355],
[ 0.9112, 0.7526, 0.4061]])
target = torch.tensor([[1.,0.,1.],[0.,1.,0.]])
weight = torch.tensor([0.5,0.2,0.3])
loss = torch.nn.BCELoss(weight=weight,reduction='sum')
result = loss(input,target)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_6():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
input = torch.tensor([[0.2837, 0.0297, 0.0355],
[ 0.9112, 0.7526, 0.4061]])
target = torch.tensor([[1.,0.,1.],[0.,1.,0.]])
weight = torch.tensor([0.5,0.2,0.3])
loss = torch.nn.BCELoss(weight=weight,reduce=True)
result = loss(input,target)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_7():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
input = torch.tensor([[0.2837, 0.0297, 0.0355],
[ 0.9112, 0.7526, 0.4061]])
target = torch.tensor([[1.,0.,1.],[0.,1.,0.]])
weight = torch.tensor([0.5,0.2,0.3])
loss = torch.nn.BCELoss(weight=weight,reduce=False)
result = loss(input,target)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_8():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
input = torch.tensor([[0.2837, 0.0297, 0.0355],
[ 0.9112, 0.7526, 0.4061]])
target = torch.tensor([[1.,0.,1.],[0.,1.,0.]])
loss = torch.nn.BCELoss()
result = loss(input,target)
"""
)
obj.run(pytorch_code, ["result"])
Loading

0 comments on commit 182bd89

Please sign in to comment.