From 88892643577aa506cebb841de4578758234d04fc Mon Sep 17 00:00:00 2001 From: co63oc Date: Mon, 18 Dec 2023 16:08:20 +0800 Subject: [PATCH] Fix --- paconvert/api_mapping.json | 19 +++++++++ paconvert/api_matcher.py | 2 + tests/test_histogramdd.py | 84 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 105 insertions(+) create mode 100644 tests/test_histogramdd.py diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json index 3bb88859d..cfd0665fe 100644 --- a/paconvert/api_mapping.json +++ b/paconvert/api_mapping.json @@ -6638,6 +6638,24 @@ "out" ] }, + "torch.histogramdd": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.histogramdd", + "min_input_args": 2, + "args_list": [ + "input", + "bins", + "*", + "range", + "weight", + "density" + ], + "kwargs_change": { + "input": "x", + "range": "ranges", + "weight": "weights" + } + }, "torch.hstack": { "Matcher": "HStackMatcher", "args_list": [ @@ -12223,6 +12241,7 @@ "torch.optim.lr_scheduler.CosineAnnealingWarmRestarts": { "Matcher": "Optim2LrSchedulerMatcher", "paddle_api": "paddle.optimizer.lr.CosineAnnealingWarmRestarts", + "min_input_args": 2, "args_list": [ "optimizer", "T_0", diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py index 11e9421e8..356e51d26 100644 --- a/paconvert/api_matcher.py +++ b/paconvert/api_matcher.py @@ -2128,6 +2128,8 @@ class TensorToMatcher(BaseMatcher): def get_paddle_nodes(self, args, kwargs): new_args = self.parse_args(args) new_kwargs = self.parse_kwargs(kwargs) + if new_kwargs is None: + new_kwargs = {} if "copy" in new_kwargs: new_kwargs.pop("copy") if "memory_format" in new_kwargs: diff --git a/tests/test_histogramdd.py b/tests/test_histogramdd.py new file mode 100644 index 000000000..ed8611d46 --- /dev/null +++ b/tests/test_histogramdd.py @@ -0,0 +1,84 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.histogramdd") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[0., 1.], [1., 0.], [2.,0.], [2., 2.]]) + bins = [3,3] + weights = torch.tensor([1., 2., 4., 8.]) + result = torch.histogramdd(x, bins=bins, weight=weights) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[0., 1.], [1., 0.], [2.,0.], [2., 2.]]) + bins = [3,3] + weights = torch.tensor([1., 2., 4., 8.]) + result = torch.histogramdd(input=x, bins=bins, weight=weights) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[0., 1.], [1., 0.], [2.,0.], [2., 2.]]) + bins = [3,3] + weights = torch.tensor([1., 2., 4., 8.]) + result = torch.histogramdd(input=x, weight=weights, bins=bins) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[0., 1.], [1., 0.], [2.,0.], [2., 2.]]) + bins = [3,3] + weights = torch.tensor([1., 2., 4., 8.]) + result = torch.histogramdd(input=x, bins=bins, range=None, weight=weights, density=True) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[0., 1.], [1., 0.], [2.,0.], [2., 2.]]) + bins = [3,3] + weights = torch.tensor([1., 2., 4., 8.]) + result = torch.histogramdd(x, bins, range=None, weight=weights, density=True) + """ + ) + obj.run(pytorch_code, ["result"])