Skip to content

Commit

Permalink
Add test cases 6-27 (#140)
Browse files Browse the repository at this point in the history
* Add test case 6-27

* Fix bugs

* Fix bugs

* Fix bugs

* Fix bugs
  • Loading branch information
huajiao-hjyp authored Jun 29, 2023
1 parent 3a9ce60 commit 207a13e
Show file tree
Hide file tree
Showing 34 changed files with 2,115 additions and 48 deletions.
67 changes: 20 additions & 47 deletions paconvert/api_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -6083,33 +6083,27 @@
"paddle_api": "paddle.device.get_device"
},
"torch.Tensor.gt": {
"Matcher": "GenericMatcher",
"Matcher": "Num2TensorBinaryMatcher",
"paddle_api": "paddle.Tensor.greater_than",
"args_list": [
"other"
],
"kwargs_change": {
"other": "y"
}
]
},
"torch.Tensor.greater": {
"Matcher": "GenericMatcher",
"Matcher": "Num2TensorBinaryMatcher",
"paddle_api": "paddle.Tensor.greater_than",
"args_list": [
"other"
],
"kwargs_change": {
"other": "y"
}
]
},
"torch.Tensor.heaviside": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.Tensor.heaviside",
"args_list": [
"values"
"data"
],
"kwargs_change": {
"other": "y"
"data": "y"
}
},
"torch.Tensor.index_select": {
Expand Down Expand Up @@ -6196,24 +6190,18 @@
}
},
"torch.Tensor.le": {
"Matcher": "GenericMatcher",
"Matcher": "Num2TensorBinaryMatcher",
"paddle_api": "paddle.Tensor.less_equal",
"args_list": [
"other"
],
"kwargs_change": {
"other": "y"
}
]
},
"torch.Tensor.less_equal": {
"Matcher": "GenericMatcher",
"Matcher": "Num2TensorBinaryMatcher",
"paddle_api": "paddle.Tensor.less_equal",
"args_list": [
"other"
],
"kwargs_change": {
"other": "y"
}
]
},
"torch.Tensor.lerp": {
"Matcher": "GenericMatcher",
Expand Down Expand Up @@ -6268,38 +6256,29 @@
}
},
"torch.Tensor.logical_and": {
"Matcher": "GenericMatcher",
"Matcher": "TensorLogicalMatcher",
"paddle_api": "paddle.Tensor.logical_and",
"args_list": [
"other"
],
"kwargs_change": {
"other": "y"
}
]
},
"torch.Tensor.logical_not": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.Tensor.logical_not"
},
"torch.Tensor.logical_or": {
"Matcher": "GenericMatcher",
"Matcher": "TensorLogicalMatcher",
"paddle_api": "paddle.Tensor.logical_or",
"args_list": [
"other"
],
"kwargs_change": {
"other": "y"
}
]
},
"torch.Tensor.logical_xor": {
"Matcher": "GenericMatcher",
"Matcher": "TensorLogicalMatcher",
"paddle_api": "paddle.Tensor.logical_xor",
"args_list": [
"other"
],
"kwargs_change": {
"other": "y"
}
]
},
"torch.Tensor.logit": {
"Matcher": "GenericMatcher",
Expand All @@ -6309,24 +6288,18 @@
]
},
"torch.Tensor.lt": {
"Matcher": "GenericMatcher",
"Matcher": "Num2TensorBinaryMatcher",
"paddle_api": "paddle.Tensor.less_than",
"args_list": [
"other"
],
"kwargs_change": {
"other": "y"
}
]
},
"torch.Tensor.less": {
"Matcher": "GenericMatcher",
"Matcher": "Num2TensorBinaryMatcher",
"paddle_api": "paddle.Tensor.less_than",
"args_list": [
"other"
],
"kwargs_change": {
"other": "y"
}
]
},
"torch.Tensor.lu": {
"Matcher": "GenericMatcher",
Expand Down
10 changes: 10 additions & 0 deletions paconvert/api_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3620,3 +3620,13 @@ def generate_code(self, kwargs):
self.get_paddle_api(), self.paddleClass, self.kwargs_to_str(kwargs)
)
return code


class TensorLogicalMatcher(BaseMatcher):
def generate_code(self, kwargs):

code = "{}(y=({}).astype(({}).dtype))".format(
self.get_paddle_api(), kwargs["other"], self.paddleClass
)

return code
79 changes: 79 additions & 0 deletions tests/test_Tensor_ger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.Tensor.ger")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([1., 2, 3])
y = torch.tensor([1., 2, 3, 4])
result = x.ger(y)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([1., 2, 3])
y = torch.tensor([1., 2, 3, 4])
result = x.ger(vec2=y)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
result = torch.tensor([1., 2, 3]).ger(torch.tensor([1., 2, 3, 4]))
"""
)
obj.run(pytorch_code, ["result"])


# The paddle input does not support integer type
def _test_case_4():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([1, 2, 3])
y = torch.tensor([1, 2, 3, 4])
result = x.ger(y)
"""
)
obj.run(pytorch_code, ["result"])


# The paddle other does not support integer type
def _test_case_5():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([1., 2., 3.])
y = torch.tensor([1, 2, 3, 4])
result = x.ger(y)
"""
)
obj.run(pytorch_code, ["result"])
75 changes: 75 additions & 0 deletions tests/test_Tensor_greater.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.Tensor.greater")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
result = torch.tensor([[1, 2], [3, 4]]).greater(torch.tensor([[1, 1], [4, 4]]))
"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
input = torch.tensor([[1, 2], [3, 4]])
other = torch.tensor([[1, 1], [4, 4]])
result = input.greater(other)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
input = torch.tensor([[1, 2], [3, 4]])
other = torch.tensor([[1, 2], [3, 4]])
result = input.greater(other=other)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_4():
pytorch_code = textwrap.dedent(
"""
import torch
input = torch.tensor([[1, 2], [3, 4]])
other = torch.tensor([1, 2])
result = input.greater(other)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_5():
pytorch_code = textwrap.dedent(
"""
import torch
result = torch.tensor([[1, 2], [3, 4]]).greater(2)
"""
)
obj.run(pytorch_code, ["result"])
75 changes: 75 additions & 0 deletions tests/test_Tensor_gt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.Tensor.gt")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
result = torch.tensor([[1, 2], [3, 4]]).gt(torch.tensor([[1, 1], [4, 4]]))
"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
input = torch.tensor([[1, 2], [3, 4]])
other = torch.tensor([[1, 1], [4, 4]])
result = input.gt(other)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
input = torch.tensor([[1, 2], [3, 4]])
other = torch.tensor([[1, 2], [3, 4]])
result = input.gt(other=other)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_4():
pytorch_code = textwrap.dedent(
"""
import torch
input = torch.tensor([[1, 2], [3, 4]])
other = torch.tensor([1, 2])
result = input.gt(other)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_5():
pytorch_code = textwrap.dedent(
"""
import torch
result = torch.tensor([[1, 2], [3, 4]]).gt(2)
"""
)
obj.run(pytorch_code, ["result"])
Loading

0 comments on commit 207a13e

Please sign in to comment.