Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
co63oc committed Jul 5, 2023
1 parent 96300cb commit 8629c04
Show file tree
Hide file tree
Showing 8 changed files with 595 additions and 2 deletions.
111 changes: 111 additions & 0 deletions paconvert/api_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -2834,6 +2834,107 @@
"input": "x"
}
},
"torch.autograd.backward": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.autograd.backward",
"args_list": [
"tensors",
"grad_tensors",
"retain_graph",
"create_graph",
"grad_variables",
"inputs"
],
"unsupport_args": [
"create_graph",
"grad_variables",
"inputs"
]
},
"torch.autograd.functional.hessian": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.incubate.autograd.Hessian",
"args_list": [
"func",
"inputs",
"create_graph",
"strict",
"vectorize",
"outer_jacobian_strategy"
],
"unsupport_args": [
"create_graph",
"strict",
"vectorize",
"outer_jacobian_strategy"
],
"kwargs_change": {
"inputs": "xs"
},
"paddle_default_kwargs": {
"is_batched": false
}
},
"torch.autograd.functional.jacobian": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.incubate.autograd.Jacobian",
"args_list": [
"func",
"inputs",
"create_graph",
"strict",
"vectorize",
"strategy"
],
"unsupport_args": [
"create_graph",
"strict",
"vectorize",
"strategy"
],
"kwargs_change": {
"inputs": "xs"
},
"paddle_default_kwargs": {
"is_batched": false
}
},
"torch.autograd.functional.jvp": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.incubate.autograd.jvp",
"args_list": [
"func",
"inputs",
"v",
"create_graph",
"strict"
],
"unsupport_args": [
"create_graph",
"strict"
],
"kwargs_change": {
"inputs": "xs"
}
},
"torch.autograd.functional.vjp": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.incubate.autograd.vjp",
"args_list": [
"func",
"inputs",
"v",
"create_graph",
"strict"
],
"unsupport_args": [
"create_graph",
"strict"
],
"kwargs_change": {
"inputs": "xs"
}
},
"torch.autograd.grad": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.grad",
Expand Down Expand Up @@ -3309,6 +3410,16 @@
"device": "device"
}
},
"torch.cuda.device": {
"Matcher": "CudaDeviceMatcher",
"paddle_api": "paddle.CUDAPlace",
"args_list": [
"device"
],
"kwargs_change": {
"device": "id"
}
},
"torch.cuda.device_count": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.device.cuda.device_count"
Expand Down
41 changes: 39 additions & 2 deletions paconvert/api_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import ast
import re
import textwrap

import astor
Expand All @@ -28,7 +29,7 @@ def get_paddle_api(self):
return self.paddle_api
return self.api_mapping["paddle_api"]

def generate_code(self, kwargs):
def generate_code(self, kwargs, args=[]):
kwargs_change = {}
if "kwargs_change" in self.api_mapping:
kwargs_change = self.api_mapping["kwargs_change"]
Expand Down Expand Up @@ -80,7 +81,9 @@ def generate_code(self, kwargs):
if "out" in new_kwargs:
out_v = new_kwargs.pop("out")

res = "{}({})".format(self.get_paddle_api(), self.kwargs_to_str(new_kwargs))
res = "{}({})".format(
self.get_paddle_api(), self.args_and_kwargs_to_str(args, new_kwargs)
)

if dtype_v:
res += ".astype({})".format(dtype_v)
Expand Down Expand Up @@ -3777,3 +3780,37 @@ def generate_code(self, kwargs):
code = "{}".format(self.get_paddle_api())

return code


class CudaDeviceMatcher(BaseMatcher):
def generate_code(self, kwargs):
if not kwargs["device"].strip("()").isdigit():
device = kwargs["device"]
if (
"replace('cuda', 'gpu')," in device
or 'replace("cuda", "gpu"),' in device
):
m = re.search(r"\(([0-9]+)\)", device)
if m:
kwargs["device"] = m.group(1)
else:
return None
elif (
"replace('cuda', 'gpu')" in device or 'replace("cuda", "gpu")' in device
):
kwargs["device"] = 0
else:
return None

kwargs_change = {}
if "kwargs_change" in self.api_mapping:
kwargs_change = self.api_mapping["kwargs_change"]

args = []
new_kwargs = {}
for ele in kwargs:
if ele in kwargs_change and kwargs_change[ele] == "id":
args.append(kwargs[ele])
else:
new_kwargs[ele] = kwargs[ele]
return GenericMatcher.generate_code(self, new_kwargs, args)
77 changes: 77 additions & 0 deletions tests/test_autograd_backward.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import textwrap

from apibase import APIBase

obj = APIBase("torch.autograd.backward")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32, requires_grad=True)
y = torch.tensor([[3, 2], [3, 4]], dtype=torch.float32)
grad_tensor1 = torch.tensor([[1,2], [2, 3]], dtype=torch.float32)
grad_tensor2 = torch.tensor([[1,1], [1, 1]], dtype=torch.float32)
z1 = torch.matmul(x, y)
z2 = torch.matmul(x, y)
torch.autograd.backward([z1, z2], [grad_tensor1, grad_tensor2], True)
x.grad.requires_grad=False
result = x.grad
"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32, requires_grad=True)
y = torch.tensor([[3, 2], [3, 4]], dtype=torch.float32)
grad_tensor1 = torch.tensor([[1,2], [2, 3]], dtype=torch.float32)
grad_tensor2 = torch.tensor([[1,1], [1, 1]], dtype=torch.float32)
z1 = torch.matmul(x, y)
z2 = torch.matmul(x, y)
torch.autograd.backward([z1, z2], [grad_tensor1, grad_tensor2], retain_graph=False)
x.grad.requires_grad=False
result = x.grad
"""
)
obj.run(pytorch_code, ["result"])


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32, requires_grad=True)
z1 = x.sum()
torch.autograd.backward([z1])
x.grad.requires_grad=False
result = x.grad
"""
)
obj.run(pytorch_code, ["result"])
70 changes: 70 additions & 0 deletions tests/test_autograd_functional_hessian.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import textwrap

from apibase import APIBase

obj = APIBase("torch.autograd.functional.hessian")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
def func(x):
return torch.sum(x * x)
x = torch.rand(2, 2)
h = torch.autograd.functional.hessian(func, x)
result = h[:]
result.requires_grad = False
result = torch.flatten(result)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
def func(x):
return 2 * torch.sum(x * x + 3 * x)
x = torch.rand(2, 2)
h = torch.autograd.functional.hessian(func, x)
result = h[:]
result.requires_grad = False
result = torch.flatten(result)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
def func(x):
return torch.sum(x)
x = torch.tensor([1.0, 2.0])
h = torch.autograd.functional.hessian(func, x)
result = h[:]
result.requires_grad = False
result = torch.flatten(result)
"""
)
obj.run(pytorch_code, ["result"])
Loading

0 comments on commit 8629c04

Please sign in to comment.