Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

规则转换 No.234/236/237 #131

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions paconvert/api_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -8613,6 +8613,30 @@
"data_source"
]
},
"torch.utils.data.TensorDataset": {
"Matcher": "TensorDatasetMatcher",
"paddle_api": "paddle.io.TensorDataset"
},
"torch.utils.data.IterableDataset": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.io.IterableDataset",
"args_list": []
},
"torch.utils.data.ChainDataset": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.io.ChainDataset",
"args_list": [
"datasets"
]
},
"torch.utils.data.Subset": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.io.Subset",
"args_list": [
"dataset",
"indices"
]
},
"torch.utils.data.random_split": {
"Matcher": "RandomSplitMatcher",
"paddle_api": "paddle.io.random_split",
Expand Down
12 changes: 12 additions & 0 deletions paconvert/api_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3680,3 +3680,15 @@ def generate_code(self, kwargs):
self.kwargs_to_str(kwargs),
)
return code


class TensorDatasetMatcher(BaseMatcher):
def get_paddle_nodes(self, args, kwargs):
new_args = self.parse_args(args)
code = "[{}".format(new_args[0])
for arg in new_args[1:]:
code += ", {}".format(arg)
code += "]"
code = "{}({})".format(self.get_paddle_api(), code)
node = ast.parse(code.strip("\n")).body
return node
106 changes: 106 additions & 0 deletions tests/test_utils_data_ChainDataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.utils.data.ChainDataset")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import numpy as np
import math
import torch
from torch.utils.data import IterableDataset, ChainDataset
class MyIterableDataset(torch.utils.data.IterableDataset):
def __init__(self, start, end):
super(MyIterableDataset).__init__()
assert end > start, "this example code only works with end >= start"
self.start = start
self.end = end

def __iter__(self):
iter_start = self.start
iter_end = self.end
return iter(range(iter_start, iter_end))


dataset = ChainDataset([MyIterableDataset(start=3, end=7), MyIterableDataset(start=3, end=7)])
result = []
for d in dataset:
result.append(d)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import numpy as np
import math
import torch
from torch.utils.data import IterableDataset, ChainDataset
class MyIterableDataset(torch.utils.data.IterableDataset):
def __init__(self, start, end):
super(MyIterableDataset).__init__()
assert end > start, "this example code only works with end >= start"
self.start = start
self.end = end

def __iter__(self):
iter_start = self.start
iter_end = self.end
return iter(range(iter_start, iter_end))


dataset = ChainDataset([MyIterableDataset(start=1, end=10), MyIterableDataset(start=1, end=3)])
result = []
for d in dataset:
result.append(d)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import numpy as np
import math
import torch
from torch.utils.data import IterableDataset, ChainDataset
class MyIterableDataset(torch.utils.data.IterableDataset):
def __init__(self, start, end):
super(MyIterableDataset).__init__()
assert end > start, "this example code only works with end >= start"
self.start = start
self.end = end

def __iter__(self):
iter_start = self.start
iter_end = self.end
return iter(range(iter_start, iter_end))


dataset = ChainDataset([MyIterableDataset(start=1, end=10)])
result = []
for d in dataset:
result.append(d)
"""
)
obj.run(pytorch_code, ["result"])
120 changes: 120 additions & 0 deletions tests/test_utils_data_Subset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.utils.data.Subset")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
from torch.utils.data import Dataset, Subset
class MyDataset(Dataset):
def __init__(self, size=10):
super(Dataset).__init__()
self.data = list(range(size))

def __getitem__(self, idx):
return self.data[idx]

def __len__(self):
return len(self.data)

dataset = Subset(MyDataset(10),[1, 2, 3, 4, 5, 6])
result = []
for d in dataset:
result.append(d)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
from torch.utils.data import Dataset, Subset
class MyDataset(Dataset):
def __init__(self, size=10):
super(Dataset).__init__()
self.data = list(range(size))

def __getitem__(self, idx):
return self.data[idx]

def __len__(self):
return len(self.data)

dataset = Subset(MyDataset(10),[9, 1])
result = []
for d in dataset:
result.append(d)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
from torch.utils.data import Dataset, Subset
class MyDataset(Dataset):
def __init__(self, size=10):
super(Dataset).__init__()
self.data = list(range(size))

def __getitem__(self, idx):
return self.data[idx]

def __len__(self):
return len(self.data)

dataset = Subset(MyDataset(10),[9, 1, 3])
result = []
for d in dataset:
result.append(d)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_4():
pytorch_code = textwrap.dedent(
"""
import torch
from torch.utils.data import Dataset, Subset
class MyDataset(Dataset):
def __init__(self, size=10):
super(Dataset).__init__()
self.data = list(range(size))

def __getitem__(self, idx):
return self.data[idx]

def __len__(self):
return len(self.data)
data = MyDataset(10)
indices = [9, 1, 3]
dataset = Subset(data, indices)
result = []
for d in dataset:
result.append(d)
"""
)
obj.run(pytorch_code, ["result"])
103 changes: 103 additions & 0 deletions tests/test_utils_data_TensorDataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.utils.data.TensorDataset")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import numpy as np
import torch
from torch.utils.data import TensorDataset
np.random.seed(0)
input_np = np.random.random([2, 3, 4]).astype('float32')
input = torch.from_numpy(input_np)
dataset = TensorDataset(input)
result = []
for d in dataset:
result.append(d)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import numpy as np
import torch
from torch.utils.data import TensorDataset
np.random.seed(0)
input_np = np.random.random([2, 3, 4]).astype('float32')
input = torch.from_numpy(input_np)
label_np = np.random.random([2, 1]).astype('int32')
label = torch.from_numpy(label_np)
dataset = TensorDataset(input, label)
result = []
for d in dataset:
result.append(d)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import numpy as np
import torch
from torch.utils.data import TensorDataset
np.random.seed(0)
input_np = np.random.random([2, 3, 4]).astype('float32')
input = torch.from_numpy(input_np)
input_np2 = np.random.random([2, 5, 5]).astype('float32')
input2 = torch.from_numpy(input_np2)
label_np = np.random.random([2, 1]).astype('int32')
label = torch.from_numpy(label_np)
dataset = TensorDataset(input, input2, label)
result = []
for d in dataset:
result.append(d)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_4():
pytorch_code = textwrap.dedent(
"""
import numpy as np
import torch
from torch.utils.data import TensorDataset
np.random.seed(0)
input_np = np.random.random([2, 3, 4]).astype('float32')
input = torch.from_numpy(input_np)
input_np2 = np.random.random([2, 5, 5]).astype('float32')
input2 = torch.from_numpy(input_np2)
label_np = np.random.random([2, 1]).astype('int32')
label = torch.from_numpy(label_np)
data = [input, input2, label]

dataset = TensorDataset(*data)
result = []
for d in dataset:
result.append(d)
"""
)
obj.run(pytorch_code, ["result"])