Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

转换规则 No. 114-120 #122

Merged
merged 5 commits into from
Jul 20, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions paconvert/api_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -2851,6 +2851,107 @@
"input": "x"
}
},
"torch.autograd.backward": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.autograd.backward",
"args_list": [
"tensors",
"grad_tensors",
"retain_graph",
"create_graph",
"grad_variables",
"inputs"
],
"unsupport_args": [
"create_graph",
"grad_variables",
"inputs"
]
},
"torch.autograd.functional.hessian": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.incubate.autograd.Hessian",
"args_list": [
"func",
"inputs",
"create_graph",
"strict",
"vectorize",
"outer_jacobian_strategy"
],
"unsupport_args": [
"create_graph",
"strict",
"vectorize",
"outer_jacobian_strategy"
],
"kwargs_change": {
"inputs": "xs"
},
"paddle_default_kwargs": {
"is_batched": false
}
},
"torch.autograd.functional.jacobian": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.incubate.autograd.Jacobian",
"args_list": [
"func",
"inputs",
"create_graph",
"strict",
"vectorize",
"strategy"
],
"unsupport_args": [
"create_graph",
"strict",
"vectorize",
"strategy"
],
"kwargs_change": {
"inputs": "xs"
},
"paddle_default_kwargs": {
"is_batched": false
}
},
"torch.autograd.functional.jvp": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.incubate.autograd.jvp",
"args_list": [
"func",
"inputs",
"v",
"create_graph",
"strict"
],
"unsupport_args": [
"create_graph",
"strict"
],
"kwargs_change": {
"inputs": "xs"
}
},
"torch.autograd.functional.vjp": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.incubate.autograd.vjp",
"args_list": [
"func",
"inputs",
"v",
"create_graph",
"strict"
],
"unsupport_args": [
"create_graph",
"strict"
],
"kwargs_change": {
"inputs": "xs"
}
},
"torch.autograd.grad": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.grad",
Expand Down
77 changes: 77 additions & 0 deletions tests/test_autograd_backward.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import textwrap

from apibase import APIBase

obj = APIBase("torch.autograd.backward")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32, requires_grad=True)
y = torch.tensor([[3, 2], [3, 4]], dtype=torch.float32)
grad_tensor1 = torch.tensor([[1,2], [2, 3]], dtype=torch.float32)
grad_tensor2 = torch.tensor([[1,1], [1, 1]], dtype=torch.float32)
z1 = torch.matmul(x, y)
z2 = torch.matmul(x, y)
torch.autograd.backward([z1, z2], [grad_tensor1, grad_tensor2], True)
x.grad.requires_grad=False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个最后是不是不设置.requires_grad,能比较不

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

返回 requires_grad 属性不一致错误
图片

result = x.grad
"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32, requires_grad=True)
y = torch.tensor([[3, 2], [3, 4]], dtype=torch.float32)
grad_tensor1 = torch.tensor([[1,2], [2, 3]], dtype=torch.float32)
grad_tensor2 = torch.tensor([[1,1], [1, 1]], dtype=torch.float32)
z1 = torch.matmul(x, y)
z2 = torch.matmul(x, y)
torch.autograd.backward([z1, z2], [grad_tensor1, grad_tensor2], retain_graph=False)
x.grad.requires_grad=False
result = x.grad
"""
)
obj.run(pytorch_code, ["result"])


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32, requires_grad=True)
z1 = x.sum()
torch.autograd.backward([z1])
x.grad.requires_grad=False
result = x.grad
"""
)
obj.run(pytorch_code, ["result"])
70 changes: 70 additions & 0 deletions tests/test_autograd_functional_hessian.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import textwrap

from apibase import APIBase

obj = APIBase("torch.autograd.functional.hessian")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
def func(x):
return torch.sum(x * x)
x = torch.rand(2, 2)
h = torch.autograd.functional.hessian(func, x)
result = h[:]
result.requires_grad = False
result = torch.flatten(result)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

后面不加这些处理,直接比较h能跑过不,或者直接比较result能跑过不

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

使用h返回没有numpy属性错误
图片

"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
def func(x):
return 2 * torch.sum(x * x + 3 * x)
x = torch.rand(2, 2)
h = torch.autograd.functional.hessian(func, x)
result = h[:]
result.requires_grad = False
result = torch.flatten(result)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
def func(x):
return torch.sum(x)
x = torch.tensor([1.0, 2.0])
h = torch.autograd.functional.hessian(func, x)
result = h[:]
result.requires_grad = False
result = torch.flatten(result)
"""
)
obj.run(pytorch_code, ["result"])
69 changes: 69 additions & 0 deletions tests/test_autograd_functional_jacobian.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import textwrap

from apibase import APIBase

obj = APIBase("torch.autograd.functional.jacobian")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
def func(x):
return x * x
x = torch.tensor([1., 2.])
J = torch.autograd.functional.jacobian(func, x)
result = J[:]
result.requires_grad = False
"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
def func(x):
return torch.cos(x)
x = torch.tensor([1., 2.])
J = torch.autograd.functional.jacobian(func, x)
result = J[:]
result.requires_grad = False
"""
)
obj.run(pytorch_code, ["result"])


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
def func(x):
return torch.log(x)
x = torch.tensor([1., 2.])
J = torch.autograd.functional.jacobian(func, x)
result = J[:]
result.requires_grad = False
"""
)
obj.run(pytorch_code, ["result"])
Loading