From 9d691ccd007e699f02a2cff73c5cd0d0c894dd4e Mon Sep 17 00:00:00 2001
From: txyugood
Date: Wed, 5 Jul 2023 16:15:20 +0800
Subject: [PATCH] Conversion rules No.234/236/237 (#133)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Conversion rules No.234/236/237

* Fix variable naming in TensorDatasetMatcher.
---
 paconvert/api_mapping.json             |  19 ++++
 paconvert/api_matcher.py               |  12 +++
 tests/test_utils_data_ChainDataset.py  | 106 ++++++++++++++++++++++
 tests/test_utils_data_Subset.py        | 120 +++++++++++++++++++++++++
 tests/test_utils_data_TensorDataset.py | 103 +++++++++++++++++++++
 5 files changed, 360 insertions(+)
 create mode 100644 tests/test_utils_data_ChainDataset.py
 create mode 100644 tests/test_utils_data_Subset.py
 create mode 100644 tests/test_utils_data_TensorDataset.py

diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json
index 0d98c93c6..a8075c746 100644
--- a/paconvert/api_mapping.json
+++ b/paconvert/api_mapping.json
@@ -9194,6 +9194,13 @@
             "drop_last"
         ]
     },
+    "torch.utils.data.ChainDataset": {
+        "Matcher": "GenericMatcher",
+        "paddle_api": "paddle.io.ChainDataset",
+        "args_list": [
+            "datasets"
+        ]
+    },
     "torch.utils.data.Dataset": {
         "Matcher": "GenericMatcher",
         "paddle_api": "paddle.io.Dataset"
     },
@@ -9245,6 +9252,18 @@
             "data_source"
         ]
     },
+    "torch.utils.data.Subset": {
+        "Matcher": "GenericMatcher",
+        "paddle_api": "paddle.io.Subset",
+        "args_list": [
+            "dataset",
+            "indices"
+        ]
+    },
+    "torch.utils.data.TensorDataset": {
+        "Matcher": "TensorDatasetMatcher",
+        "paddle_api": "paddle.io.TensorDataset"
+    },
     "torch.utils.data.default_collate": {
         "Matcher": "GenericMatcher",
         "paddle_api": "paddle.io.dataloader.collate.default_collate_fn",
diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py
index 50382d3d1..32e6134b3 100644
--- a/paconvert/api_matcher.py
+++ b/paconvert/api_matcher.py
@@ -3687,6 +3687,18 @@ def generate_code(self, kwargs):
         return code


+class TensorDatasetMatcher(BaseMatcher):
+    def get_paddle_nodes(self, args, kwargs):
+        new_args = self.parse_args(args)
+        tensors_v = "[{}".format(new_args[0])
+        for arg in new_args[1:]:
+            tensors_v += ", {}".format(arg)
+        tensors_v += "]"
+        code = "{}({})".format(self.get_paddle_api(), tensors_v)
+        node = ast.parse(code.strip("\n")).body
+        return node
+
+
 class TensorMaxMinMatcher(BaseMatcher):
     def get_paddle_class_nodes(self, func, args, kwargs):
diff --git a/tests/test_utils_data_ChainDataset.py b/tests/test_utils_data_ChainDataset.py
new file mode 100644
index 000000000..a3715adbd
--- /dev/null
+++ b/tests/test_utils_data_ChainDataset.py
@@ -0,0 +1,106 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import textwrap
+
+from apibase import APIBase
+
+obj = APIBase("torch.utils.data.ChainDataset")
+
+
+def test_case_1():
+    pytorch_code = textwrap.dedent(
+        """
+        import numpy as np
+        import math
+        import torch
+        from torch.utils.data import IterableDataset, ChainDataset
+        class MyIterableDataset(torch.utils.data.IterableDataset):
+            def __init__(self, start, end):
+                super().__init__()
+                assert end > start, "this example code only works with end > start"
+                self.start = start
+                self.end = end
+
+            def __iter__(self):
+                iter_start = self.start
+                iter_end = self.end
+                return iter(range(iter_start, iter_end))
+
+
+        dataset = ChainDataset([MyIterableDataset(start=3, end=7), MyIterableDataset(start=3, end=7)])
+        result = []
+        for d in dataset:
+            result.append(d)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_2():
+    pytorch_code = textwrap.dedent(
+        """
+        import numpy as np
+        import math
+        import torch
+        from torch.utils.data import IterableDataset, ChainDataset
+        class MyIterableDataset(torch.utils.data.IterableDataset):
+            def __init__(self, start, end):
+                super().__init__()
+                assert end > start, "this example code only works with end > start"
+                self.start = start
+                self.end = end
+
+            def __iter__(self):
+                iter_start = self.start
+                iter_end = self.end
+                return iter(range(iter_start, iter_end))
+
+
+        dataset = ChainDataset([MyIterableDataset(start=1, end=10), MyIterableDataset(start=1, end=3)])
+        result = []
+        for d in dataset:
+            result.append(d)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_3():
+    pytorch_code = textwrap.dedent(
+        """
+        import numpy as np
+        import math
+        import torch
+        from torch.utils.data import IterableDataset, ChainDataset
+        class MyIterableDataset(torch.utils.data.IterableDataset):
+            def __init__(self, start, end):
+                super().__init__()
+                assert end > start, "this example code only works with end > start"
+                self.start = start
+                self.end = end
+
+            def __iter__(self):
+                iter_start = self.start
+                iter_end = self.end
+                return iter(range(iter_start, iter_end))
+
+
+        dataset = ChainDataset([MyIterableDataset(start=1, end=10)])
+        result = []
+        for d in dataset:
+            result.append(d)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
diff --git a/tests/test_utils_data_Subset.py b/tests/test_utils_data_Subset.py
new file mode 100644
index 000000000..df9368bbd
--- /dev/null
+++ b/tests/test_utils_data_Subset.py
@@ -0,0 +1,120 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import textwrap
+
+from apibase import APIBase
+
+obj = APIBase("torch.utils.data.Subset")
+
+
+def test_case_1():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        from torch.utils.data import Dataset, Subset
+        class MyDataset(Dataset):
+            def __init__(self, size=10):
+                super().__init__()
+                self.data = list(range(size))
+
+            def __getitem__(self, idx):
+                return self.data[idx]
+
+            def __len__(self):
+                return len(self.data)
+
+        dataset = Subset(MyDataset(10), [1, 2, 3, 4, 5, 6])
+        result = []
+        for d in dataset:
+            result.append(d)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_2():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        from torch.utils.data import Dataset, Subset
+        class MyDataset(Dataset):
+            def __init__(self, size=10):
+                super().__init__()
+                self.data = list(range(size))
+
+            def __getitem__(self, idx):
+                return self.data[idx]
+
+            def __len__(self):
+                return len(self.data)
+
+        dataset = Subset(MyDataset(10), [9, 1])
+        result = []
+        for d in dataset:
+            result.append(d)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_3():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        from torch.utils.data import Dataset, Subset
+        class MyDataset(Dataset):
+            def __init__(self, size=10):
+                super().__init__()
+                self.data = list(range(size))
+
+            def __getitem__(self, idx):
+                return self.data[idx]
+
+            def __len__(self):
+                return len(self.data)
+
+        dataset = Subset(MyDataset(10), [9, 1, 3])
+        result = []
+        for d in dataset:
+            result.append(d)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_4():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        from torch.utils.data import Dataset, Subset
+        class MyDataset(Dataset):
+            def __init__(self, size=10):
+                super().__init__()
+                self.data = list(range(size))
+
+            def __getitem__(self, idx):
+                return self.data[idx]
+
+            def __len__(self):
+                return len(self.data)
+        data = MyDataset(10)
+        indices = [9, 1, 3]
+        dataset = Subset(data, indices)
+        result = []
+        for d in dataset:
+            result.append(d)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
diff --git a/tests/test_utils_data_TensorDataset.py b/tests/test_utils_data_TensorDataset.py
new file mode 100644
index 000000000..74b037061
--- /dev/null
+++ b/tests/test_utils_data_TensorDataset.py
@@ -0,0 +1,103 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import textwrap
+
+from apibase import APIBase
+
+obj = APIBase("torch.utils.data.TensorDataset")
+
+
+def test_case_1():
+    pytorch_code = textwrap.dedent(
+        """
+        import numpy as np
+        import torch
+        from torch.utils.data import TensorDataset
+        np.random.seed(0)
+        input_np = np.random.random([2, 3, 4]).astype('float32')
+        input = torch.from_numpy(input_np)
+        dataset = TensorDataset(input)
+        result = []
+        for d in dataset:
+            result.append(d)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_2():
+    pytorch_code = textwrap.dedent(
+        """
+        import numpy as np
+        import torch
+        from torch.utils.data import TensorDataset
+        np.random.seed(0)
+        input_np = np.random.random([2, 3, 4]).astype('float32')
+        input = torch.from_numpy(input_np)
+        label_np = np.random.random([2, 1]).astype('int32')
+        label = torch.from_numpy(label_np)
+        dataset = TensorDataset(input, label)
+        result = []
+        for d in dataset:
+            result.append(d)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_3():
+    pytorch_code = textwrap.dedent(
+        """
+        import numpy as np
+        import torch
+        from torch.utils.data import TensorDataset
+        np.random.seed(0)
+        input_np = np.random.random([2, 3, 4]).astype('float32')
+        input = torch.from_numpy(input_np)
+        input_np2 = np.random.random([2, 5, 5]).astype('float32')
+        input2 = torch.from_numpy(input_np2)
+        label_np = np.random.random([2, 1]).astype('int32')
+        label = torch.from_numpy(label_np)
+        dataset = TensorDataset(input, input2, label)
+        result = []
+        for d in dataset:
+            result.append(d)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_4():
+    pytorch_code = textwrap.dedent(
+        """
+        import numpy as np
+        import torch
+        from torch.utils.data import TensorDataset
+        np.random.seed(0)
+        input_np = np.random.random([2, 3, 4]).astype('float32')
+        input = torch.from_numpy(input_np)
+        input_np2 = np.random.random([2, 5, 5]).astype('float32')
+        input2 = torch.from_numpy(input_np2)
+        label_np = np.random.random([2, 1]).astype('int32')
+        label = torch.from_numpy(label_np)
+        data = [input, input2, label]
+
+        dataset = TensorDataset(*data)
+        result = []
+        for d in dataset:
+            result.append(d)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
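
For reference, the rewrite that TensorDatasetMatcher performs can be sketched as a small standalone helper (illustrative only; rewrite_tensor_dataset is made up here and is not part of paconvert). torch.utils.data.TensorDataset accepts tensors as separate positional arguments, while paddle.io.TensorDataset takes a single list of tensors, so the matcher joins the source text of the original arguments into one list literal and parses the resulting call back into AST nodes.

import ast


def rewrite_tensor_dataset(paddle_api, arg_sources):
    # arg_sources holds the source text of the original positional arguments,
    # e.g. ["input", "label"] for torch.utils.data.TensorDataset(input, label).
    code = "{}([{}])".format(paddle_api, ", ".join(arg_sources))
    # paconvert-style matchers return AST nodes; parsing also confirms the
    # generated call is valid Python.
    return ast.parse(code).body


nodes = rewrite_tensor_dataset("paddle.io.TensorDataset", ["input", "label"])
print(ast.dump(nodes[0]))  # an Expr node wrapping the call paddle.io.TensorDataset([input, label])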
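
User code converted under these three rules is expected to look roughly like the following (a hand-written illustration rather than actual paconvert output; it requires paddlepaddle to run, and the variable names are made up):

import numpy as np
import paddle

image = paddle.to_tensor(np.random.random([4, 3]).astype("float32"))
label = paddle.to_tensor(np.random.random([4, 1]).astype("float32"))

# torch.utils.data.TensorDataset(image, label) is rewritten with the
# positional arguments wrapped in a list:
dataset = paddle.io.TensorDataset([image, label])

# torch.utils.data.Subset(dataset, [0, 2]) maps argument-for-argument via
# GenericMatcher, as does torch.utils.data.ChainDataset(datasets):
subset = paddle.io.Subset(dataset, [0, 2])

for sample in subset:
    print(sample)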