Skip to content

Commit

Permalink
Upsample (#115)
Browse files Browse the repository at this point in the history
* add nn.Upsample,CUDAExtension,CppExtension,SequentialSampler,is_sparse

* fix cpp_extension ut

* fix cpp_extension ut

* fix UtilsCppExtensionMatcher ci

* fix UtilsCppExtensionMatcher

* fix cpp_extension ut

* UtilsCppExtensionMatcher bug

* UtilsCppExtensionMatcher bug

* fix UtilsCppExtensionMatcher pop

* fix cpp test

* add cpp test

* add Attribute2Func
  • Loading branch information
LokeZhou authored Jul 5, 2023
1 parent 748524a commit fafdd56
Show file tree
Hide file tree
Showing 8 changed files with 420 additions and 1 deletion.
42 changes: 42 additions & 0 deletions paconvert/api_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -6732,6 +6732,19 @@
"unflattened_size": "shape"
}
},
"torch.nn.Upsample": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.Upsample",
"args_list": [
"size",
"scale_factor",
"mode",
"align_corners"
],
"unsupport_args": [
"recompute_scale_factor"
]
},
"torch.nn.UpsamplingBilinear2d": {
"Matcher": "UpsampleMatcher",
"paddle_api": "paddle.nn.UpsamplingBilinear2D",
Expand Down Expand Up @@ -9147,10 +9160,32 @@
"Matcher": "GenericMatcher",
"paddle_api": "paddle.utils.cpp_extension.BuildExtension.with_options"
},
"torch.utils.cpp_extension.CUDAExtension": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.utils.cpp_extension.CUDAExtension",
"args_list": [
"name",
"sources"
],
"kwargs_change": {
"name": ""
}
},
"torch.utils.cpp_extension.CUDA_HOME": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.utils.cpp_extension.cpp_extension.CUDA_HOME"
},
"torch.utils.cpp_extension.CppExtension": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.utils.cpp_extension.CppExtension",
"args_list": [
"name",
"sources"
],
"kwargs_change": {
"name": ""
}
},
"torch.utils.data.BatchSampler": {
"Matcher": "TorchUtilDataBatchSampler",
"args_list": [
Expand Down Expand Up @@ -9203,6 +9238,13 @@
"data_source"
]
},
"torch.utils.data.SequentialSampler": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.io.SequenceSampler",
"args_list": [
"data_source"
]
},
"torch.utils.data.default_collate": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.io.dataloader.collate.default_collate_fn",
Expand Down
7 changes: 7 additions & 0 deletions paconvert/api_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3572,6 +3572,13 @@ def generate_code(self, kwargs):
return GenericMatcher.generate_code(self, kwargs)


class Attribute2Func(BaseMatcher):
def get_paddle_class_attribute_nodes(self, node):
self.parse_func(node)
code = "{}()".format(self.paddle_api)
return ast.parse(code).body[0].value


class LuMatcher(BaseMatcher):
def generate_code(self, kwargs):
out_v = kwargs.pop("out") if "out" in kwargs else None
Expand Down
5 changes: 4 additions & 1 deletion paconvert/attribute_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
},
"torch.Tensor.is_meta": {},
"torch.Tensor.is_quantized": {},
"torch.Tensor.is_sparse": {},
"torch.Tensor.is_sparse": {
"Matcher": "Attribute2Func",
"paddle_api": "paddle.Tensor.is_sparse"
},
"torch.Tensor.mH": {},
"torch.Tensor.mT": {},
"torch.Tensor.names": {},
Expand Down
29 changes: 29 additions & 0 deletions tests/test_Tensor_is_sparse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import textwrap

from apibase import APIBase

obj = APIBase("torch.Tensor.is_sparse")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
a = torch.tensor([[ 0.9254, -0.6213]])
result = a.is_sparse
"""
)
obj.run(pytorch_code, ["result"])
129 changes: 129 additions & 0 deletions tests/test_nn_Upsample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.nn.Upsample")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
input = torch.tensor([[[[ 1.1524, 0.4714, 0.2857],
[-1.2533, -0.9829, -1.0981],
[ 0.1507, -1.1431, -2.0361]],
[[ 0.1024, -0.4482, 0.4137],
[ 0.9385, 0.4565, 0.7702],
[ 0.4135, -0.2587, 0.0482]]]])
m = torch.nn.Upsample(scale_factor=2, mode='nearest')
result = m(input)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
input = torch.tensor([[[[ 1.1524, 0.4714, 0.2857],
[-1.2533, -0.9829, -1.0981],
[ 0.1507, -1.1431, -2.0361]],
[[ 0.1024, -0.4482, 0.4137],
[ 0.9385, 0.4565, 0.7702],
[ 0.4135, -0.2587, 0.0482]]]])
m = torch.nn.Upsample(scale_factor=2, mode='bilinear')
result = m(input)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
input = torch.tensor([[[[ 1.1524, 0.4714, 0.2857],
[-1.2533, -0.9829, -1.0981],
[ 0.1507, -1.1431, -2.0361]],
[[ 0.1024, -0.4482, 0.4137],
[ 0.9385, 0.4565, 0.7702],
[ 0.4135, -0.2587, 0.0482]]]])
m = torch.nn.Upsample(scale_factor=2, mode='bilinear',align_corners=True)
result = m(input)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_4():
pytorch_code = textwrap.dedent(
"""
import torch
input = torch.tensor([[[[ 1.1524, 0.4714, 0.2857],
[-1.2533, -0.9829, -1.0981],
[ 0.1507, -1.1431, -2.0361]],
[[ 0.1024, -0.4482, 0.4137],
[ 0.9385, 0.4565, 0.7702],
[ 0.4135, -0.2587, 0.0482]]]])
m = torch.nn.Upsample(size=(2,2))
result = m(input)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_5():
pytorch_code = textwrap.dedent(
"""
import torch
input = torch.tensor([[[[ 1.1524, 0.4714, 0.2857],
[-1.2533, -0.9829, -1.0981],
[ 0.1507, -1.1431, -2.0361]],
[[ 0.1024, -0.4482, 0.4137],
[ 0.9385, 0.4565, 0.7702],
[ 0.4135, -0.2587, 0.0482]]]])
m = torch.nn.Upsample(scale_factor=2, mode='bilinear',align_corners=False)
result = m(input)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_6():
pytorch_code = textwrap.dedent(
"""
import torch
input = torch.tensor([[[[ 1.1524, 0.4714, 0.2857],
[-1.2533, -0.9829, -1.0981],
[ 0.1507, -1.1431, -2.0361]],
[[ 0.1024, -0.4482, 0.4137],
[ 0.9385, 0.4565, 0.7702],
[ 0.4135, -0.2587, 0.0482]]]])
m = torch.nn.Upsample(scale_factor=2, mode='bilinear',recompute_scale_factor=True)
result = m(input)
"""
)
obj.run(
pytorch_code, unsupport=True, reason="paddle unsupport recompute_scale_factor "
)
36 changes: 36 additions & 0 deletions tests/test_utils_cpp_extension_CUDAExtension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.utils.cpp_extension.CUDAExtension")


# The cuda compile not supports
def test_case_1():
pytorch_code = textwrap.dedent(
"""
from torch.utils.cpp_extension import CUDAExtension
CUDAExtension(
name='cuda_extension',
sources=['extension.cpp', 'extension_kernel.cu'],
extra_compile_args={'cxx': ['-g'],
'nvcc': ['-O2']})
result = True
"""
)
obj.run(pytorch_code, ["result"])
35 changes: 35 additions & 0 deletions tests/test_utils_cpp_extension_CppExtension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.utils.cpp_extension.CppExtension")


# The cpp compile not supports
def test_case_1():
pytorch_code = textwrap.dedent(
"""
from torch.utils.cpp_extension import CppExtension
CppExtension(
name='cuda_extension',
sources=['extension.cpp'],
extra_compile_args=['-g'])
result = True
"""
)
obj.run(pytorch_code, ["result"])
Loading

0 comments on commit fafdd56

Please sign in to comment.