Fix
co63oc committed Dec 21, 2023
1 parent f4b789e commit e42a5bf
Showing 7 changed files with 49 additions and 41 deletions.
1 change: 1 addition & 0 deletions paconvert/api_mapping.json
@@ -3321,6 +3321,7 @@
"min_input_args": 0,
"args_list": [
"dtype",
"dst_type",
"non_blocking"
]
},
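For context, the extra "dst_type" entry covers the second keyword spelling that exists in PyTorch for this mapping (apparently the one for torch.Tensor.type): torch.Tensor.type takes dtype, while torch.nn.Module.type takes dst_type. A small illustration of the two call styles (not part of the commit):

    import torch

    x = torch.zeros(3)
    x = x.type(dtype=torch.float64)          # torch.Tensor.type uses `dtype`

    net = torch.nn.Linear(2, 2)
    net = net.type(dst_type=torch.float64)   # torch.nn.Module.type uses `dst_type`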
6 changes: 5 additions & 1 deletion paconvert/api_matcher.py
@@ -1563,7 +1563,11 @@ def generate_code(self, kwargs):
if len(kwargs) == 0:
code = f"str({self.paddleClass}.dtype)"
else:
code = f"{self.paddleClass}.astype({kwargs['dtype']})"
# torch.nn.Module.type is converted through this torch.Tensor.type matcher, so its keyword argument is dst_type rather than dtype
if "dst_type" in kwargs:
code = f"{self.paddleClass}.astype({kwargs['dst_type']})"
else:
code = f"{self.paddleClass}.astype({kwargs['dtype']})"
return code


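As a rough standalone sketch of the dispatch above (the real logic lives inside the matcher class and formats self.paddleClass; the function name below is made up for illustration):

    def generate_type_code(paddle_class: str, kwargs: dict) -> str:
        # No arguments: report the dtype instead of casting.
        if len(kwargs) == 0:
            return f"str({paddle_class}.dtype)"
        # torch.nn.Module.type reaches this matcher with dst_type rather than dtype.
        if "dst_type" in kwargs:
            return f"{paddle_class}.astype({kwargs['dst_type']})"
        return f"{paddle_class}.astype({kwargs['dtype']})"

    # e.g. generate_type_code("x", {"dst_type": "paddle.float64"}) == "x.astype(paddle.float64)"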
4 changes: 2 additions & 2 deletions tests/test_Tensor_signbit.py
@@ -23,7 +23,7 @@ def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype='float32')
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype=torch.float32)
result = x.signbit()
"""
)
@@ -34,7 +34,7 @@ def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype='float64')
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype=torch.float64)
result = x.signbit()
"""
)
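The dtype arguments above switch from strings to torch.dtype objects because torch.tensor does not accept a string dtype; roughly (illustrative, not part of the tests):

    import torch

    # torch.tensor([-0., 1.1], dtype='float32')        # TypeError: a str is not a torch.dtype
    x = torch.tensor([-0., 1.1], dtype=torch.float32)   # accepted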
4 changes: 2 additions & 2 deletions tests/test_nn_Module_type.py
@@ -33,8 +33,8 @@ def test_case_1():
obj.run(pytorch_code, ["result"])


# Will match torch.nn.Module; the named parameter "dst_type" cannot be resolved.
def _test_case_2():
# Will match torch.Tensor.type, which resolves the "dst_type" parameter.
def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
31 changes: 14 additions & 17 deletions tests/test_optim_lr_scheduler_CosineAnnealingWarmRestarts.py
@@ -15,14 +15,14 @@
import textwrap

from apibase import APIBase
from lr_scheduler_helper import generate_torch_code
from lr_scheduler_helper import generate_lr_scheduler_test_code

obj = APIBase("torch.optim.lr_scheduler.CosineAnnealingWarmRestarts")


def test_case_1():
pytorch_code = textwrap.dedent(
generate_torch_code(
generate_lr_scheduler_test_code(
"torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(sgd, 10)"
)
)
@@ -31,58 +31,55 @@ def test_case_1():

def test_case_2():
pytorch_code = textwrap.dedent(
generate_torch_code(
"torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(sgd, T_0=10)"
generate_lr_scheduler_test_code(
"torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(sgd, T_0=10, T_mult=1)"
)
)
obj.run(pytorch_code, ["result1", "result2"], rtol=1.0e-5)


def test_case_3():
pytorch_code = textwrap.dedent(
generate_torch_code(
"torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer=sgd, T_0=10)"
generate_lr_scheduler_test_code(
"torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(T_mult=1, optimizer=sgd, T_0=10)"
)
)
obj.run(pytorch_code, ["result1", "result2"], rtol=1.0e-5)


def test_case_4():
pytorch_code = textwrap.dedent(
generate_torch_code(
"torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer=sgd, T_0=10, eta_min=0.0, last_epoch=-1, verbose=True)"
generate_lr_scheduler_test_code(
"torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer=sgd, T_0=10, T_mult=1, eta_min=0.0, last_epoch=-1, verbose=True)"
)
)
obj.run(pytorch_code, ["result1", "result2"], rtol=1.0e-5)


def test_case_5():
pytorch_code = textwrap.dedent(
generate_torch_code(
"torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer=sgd, T_0=10, eta_min=0.05, verbose=True)"
generate_lr_scheduler_test_code(
"torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer=sgd, T_0=10, T_mult=1, eta_min=0.05, verbose=True)"
)
)
obj.run(pytorch_code, ["result1", "result2"], rtol=1.0e-5)


def test_case_6():
pytorch_code = textwrap.dedent(
generate_torch_code(
generate_lr_scheduler_test_code(
"torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(sgd, 10, 1, 1.0, -1, False)"
)
)
obj.run(pytorch_code, ["result1", "result2"], rtol=1.0e-5)


# reference: https://www.paddlepaddle.org.cn/documentation/docs/en/api/paddle/optimizer/lr/CosineAnnealingDecay_en.html
# note: paddle does not support restarts
# the paddle result differs from the pytorch result
def test_case_7():
pytorch_code = textwrap.dedent(
generate_torch_code(
generate_lr_scheduler_test_code(
[
"torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer=sgd, T_0=10, eta_min=0.0, last_epoch=-1, verbose=False)",
"torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer=sgd, T_0=10, eta_min=0.0, last_epoch=scheduler_1.last_epoch, verbose=False)",
"torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer=sgd, T_0=10, T_mult=1, eta_min=0.0, last_epoch=-1, verbose=False)",
"torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer=sgd, T_0=10, T_mult=2, eta_min=0.0, last_epoch=scheduler_1.last_epoch, verbose=False)",
]
)
)
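On the added T_mult argument: in PyTorch it is the factor by which the restart period grows, so T_mult=1 keeps every cosine cycle at T_0 steps, while T_mult=2 doubles the period after each restart. A quick standalone illustration (throwaway model and optimizer, not the test helper):

    import torch

    params = [torch.nn.Parameter(torch.zeros(1))]
    sgd = torch.optim.SGD(params, lr=0.1)
    # T_0=10, T_mult=1: cycles of 10, 10, 10, ... steps; with T_mult=2 they would be 10, 20, 40, ...
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(sgd, T_0=10, T_mult=1)
    for _ in range(30):
        sgd.step()
        scheduler.step()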
28 changes: 17 additions & 11 deletions tests/test_optim_lr_scheduler_LinearLR.py
@@ -15,21 +15,23 @@
import textwrap

from apibase import APIBase
from lr_scheduler_helper import generate_torch_code
from lr_scheduler_helper import generate_lr_scheduler_test_code

obj = APIBase("torch.optim.lr_scheduler.LinearLR")


def test_case_1():
pytorch_code = textwrap.dedent(
generate_torch_code("torch.optim.lr_scheduler.LinearLR(sgd, verbose=True)")
generate_lr_scheduler_test_code(
"torch.optim.lr_scheduler.LinearLR(sgd, verbose=True)"
)
)
obj.run(pytorch_code, ["result1", "result2"], rtol=1.0e-5)


def test_case_2():
pytorch_code = textwrap.dedent(
generate_torch_code(
generate_lr_scheduler_test_code(
"torch.optim.lr_scheduler.LinearLR(sgd, start_factor=0.05, end_factor=1.0)"
)
)
@@ -38,21 +40,25 @@ def test_case_2():

def test_case_3():
pytorch_code = textwrap.dedent(
generate_torch_code("torch.optim.lr_scheduler.LinearLR(sgd, total_iters=3)")
generate_lr_scheduler_test_code(
"torch.optim.lr_scheduler.LinearLR(sgd, total_iters=3)"
)
)
obj.run(pytorch_code, ["result1", "result2"], rtol=1.0e-5)


def test_case_4():
pytorch_code = textwrap.dedent(
generate_torch_code("torch.optim.lr_scheduler.LinearLR(sgd, 0.05, 1)")
generate_lr_scheduler_test_code(
"torch.optim.lr_scheduler.LinearLR(sgd, 0.05, 1)"
)
)
obj.run(pytorch_code, ["result1", "result2"], rtol=1.0e-5)


def test_case_5():
pytorch_code = textwrap.dedent(
generate_torch_code(
generate_lr_scheduler_test_code(
"torch.optim.lr_scheduler.LinearLR(optimizer=sgd, start_factor=0.05, end_factor=1.0, total_iters=3)"
)
)
@@ -61,7 +67,7 @@ def test_case_5():

def test_case_6():
pytorch_code = textwrap.dedent(
generate_torch_code(
generate_lr_scheduler_test_code(
"torch.optim.lr_scheduler.LinearLR(start_factor=0.05, end_factor=1.0, total_iters=3, optimizer=sgd)"
)
)
@@ -70,7 +76,7 @@ def test_case_6():

def test_case_7():
pytorch_code = textwrap.dedent(
generate_torch_code(
generate_lr_scheduler_test_code(
"torch.optim.lr_scheduler.LinearLR(sgd, 0.05, 1.0, 3, -1, False)"
)
)
@@ -79,7 +85,7 @@ def test_case_7():

def test_case_8():
pytorch_code = textwrap.dedent(
generate_torch_code(
generate_lr_scheduler_test_code(
"torch.optim.lr_scheduler.LinearLR(optimizer=sgd, start_factor=0.05, end_factor=1.0, total_iters=3, last_epoch=-1, verbose=False)"
)
)
@@ -88,7 +94,7 @@ def test_case_8():

def test_case_9():
pytorch_code = textwrap.dedent(
generate_torch_code(
generate_lr_scheduler_test_code(
[
"torch.optim.lr_scheduler.LinearLR(optimizer=sgd, start_factor=0.05, end_factor=1.0, total_iters=3, last_epoch=-1, verbose=False)",
"torch.optim.lr_scheduler.LinearLR(optimizer=sgd, start_factor=0.05, end_factor=1.0, total_iters=3, last_epoch=scheduler_1.last_epoch, verbose=False)",
@@ -100,6 +106,6 @@ def test_case_9():

def test_case_10():
pytorch_code = textwrap.dedent(
generate_torch_code("torch.optim.lr_scheduler.LinearLR(sgd)")
generate_lr_scheduler_test_code("torch.optim.lr_scheduler.LinearLR(sgd)")
)
obj.run(pytorch_code, ["result1", "result2"], rtol=1.0e-5)
16 changes: 8 additions & 8 deletions tests/test_signbit.py
@@ -23,7 +23,7 @@ def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype='float32')
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype=torch.float32)
result = torch.signbit(x)
"""
)
@@ -34,7 +34,7 @@ def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype='float32')
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype=torch.float32)
result = torch.signbit(input=x)
"""
)
@@ -45,8 +45,8 @@ def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype='float32')
out = torch.tensor([])
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype=torch.float32)
out = torch.tensor([], dtype=torch.bool)
result = torch.signbit(out=out, input=x)
"""
)
@@ -57,8 +57,8 @@ def test_case_4():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype='float32')
out = torch.tensor([])
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype=torch.float32)
out = torch.tensor([], dtype=torch.bool)
result = torch.signbit(input=x, out=out)
"""
)
@@ -69,8 +69,8 @@ def test_case_5():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype='float32')
out = torch.tensor([])
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype=torch.float32)
out = torch.tensor([], dtype=torch.bool)
result = torch.signbit(x, out=out)
"""
)
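The pre-allocated out tensor is now created with dtype=torch.bool so that it matches the boolean result torch.signbit produces. Roughly (illustrative, not part of the tests):

    import torch

    x = torch.tensor([-0., 1.1, -2.1], dtype=torch.float32)
    out = torch.tensor([], dtype=torch.bool)
    result = torch.signbit(x, out=out)
    # both `result` and `out` hold tensor([True, False, True])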
