From 5e03cc49b68c72a58350b46bf6bad1d5b18b2809 Mon Sep 17 00:00:00 2001 From: co63oc Date: Thu, 21 Dec 2023 08:31:36 +0800 Subject: [PATCH] Fix --- paconvert/api_mapping.json | 1 + paconvert/api_matcher.py | 6 +++- tests/test_Tensor_signbit.py | 4 +-- tests/test_nn_Module_type.py | 4 +-- ...r_scheduler_CosineAnnealingWarmRestarts.py | 16 +++++------ tests/test_optim_lr_scheduler_LinearLR.py | 28 +++++++++++-------- tests/test_signbit.py | 16 +++++------ 7 files changed, 43 insertions(+), 32 deletions(-) diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json index be111596b..8f451b1c9 100644 --- a/paconvert/api_mapping.json +++ b/paconvert/api_mapping.json @@ -3321,6 +3321,7 @@ "min_input_args": 0, "args_list": [ "dtype", + "dst_type", "non_blocking" ] }, diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py index 5eef87d9b..e407f3e69 100644 --- a/paconvert/api_matcher.py +++ b/paconvert/api_matcher.py @@ -1563,7 +1563,11 @@ def generate_code(self, kwargs): if len(kwargs) == 0: code = f"str({self.paddleClass}.dtype)" else: - code = f"{self.paddleClass}.astype({kwargs['dtype']})" + # For torch.nn.Module.type, torch.nn.Module.type use torch.Tensor.type + if "dst_type" in kwargs: + code = f"{self.paddleClass}.astype({kwargs['dst_type']})" + else: + code = f"{self.paddleClass}.astype({kwargs['dtype']})" return code diff --git a/tests/test_Tensor_signbit.py b/tests/test_Tensor_signbit.py index b5cad33ec..ced4f975f 100644 --- a/tests/test_Tensor_signbit.py +++ b/tests/test_Tensor_signbit.py @@ -23,7 +23,7 @@ def test_case_1(): pytorch_code = textwrap.dedent( """ import torch - x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype='float32') + x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype=torch.float32) result = x.signbit() """ ) @@ -34,7 +34,7 @@ def test_case_2(): pytorch_code = textwrap.dedent( """ import torch - x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype='float64') + x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype=torch.float64) result = x.signbit() """ ) diff --git a/tests/test_nn_Module_type.py b/tests/test_nn_Module_type.py index 06f743210..a38c4acfe 100644 --- a/tests/test_nn_Module_type.py +++ b/tests/test_nn_Module_type.py @@ -33,8 +33,8 @@ def test_case_1(): obj.run(pytorch_code, ["result"]) -# Will match torch.nn.Module, the named parameter "dst_type" cannot be resolved. -def _test_case_2(): +# Will match torch.Tensor.type to resolve "dst_type" parameter. +def test_case_2(): pytorch_code = textwrap.dedent( """ import torch diff --git a/tests/test_optim_lr_scheduler_CosineAnnealingWarmRestarts.py b/tests/test_optim_lr_scheduler_CosineAnnealingWarmRestarts.py index 91bfcf098..2e1b4621d 100644 --- a/tests/test_optim_lr_scheduler_CosineAnnealingWarmRestarts.py +++ b/tests/test_optim_lr_scheduler_CosineAnnealingWarmRestarts.py @@ -15,14 +15,14 @@ import textwrap from apibase import APIBase -from lr_scheduler_helper import generate_torch_code +from lr_scheduler_helper import generate_lr_scheduler_test_code obj = APIBase("torch.optim.lr_scheduler.CosineAnnealingWarmRestarts") def test_case_1(): pytorch_code = textwrap.dedent( - generate_torch_code( + generate_lr_scheduler_test_code( "torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(sgd, 10)" ) ) @@ -31,7 +31,7 @@ def test_case_1(): def test_case_2(): pytorch_code = textwrap.dedent( - generate_torch_code( + generate_lr_scheduler_test_code( "torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(sgd, T_0=10)" ) ) @@ -40,7 +40,7 @@ def test_case_2(): def test_case_3(): pytorch_code = textwrap.dedent( - generate_torch_code( + generate_lr_scheduler_test_code( "torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer=sgd, T_0=10)" ) ) @@ -49,7 +49,7 @@ def test_case_3(): def test_case_4(): pytorch_code = textwrap.dedent( - generate_torch_code( + generate_lr_scheduler_test_code( "torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer=sgd, T_0=10, eta_min=0.0, last_epoch=-1, verbose=True)" ) ) @@ -58,7 +58,7 @@ def test_case_4(): def test_case_5(): pytorch_code = textwrap.dedent( - generate_torch_code( + generate_lr_scheduler_test_code( "torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer=sgd, T_0=10, eta_min=0.05, verbose=True)" ) ) @@ -67,7 +67,7 @@ def test_case_5(): def test_case_6(): pytorch_code = textwrap.dedent( - generate_torch_code( + generate_lr_scheduler_test_code( "torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(sgd, 10, 1, 1.0, -1, False)" ) ) @@ -79,7 +79,7 @@ def test_case_6(): # paddle result has diff with pytorch result def test_case_7(): pytorch_code = textwrap.dedent( - generate_torch_code( + generate_lr_scheduler_test_code( [ "torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer=sgd, T_0=10, eta_min=0.0, last_epoch=-1, verbose=False)", "torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer=sgd, T_0=10, eta_min=0.0, last_epoch=scheduler_1.last_epoch, verbose=False)", diff --git a/tests/test_optim_lr_scheduler_LinearLR.py b/tests/test_optim_lr_scheduler_LinearLR.py index f39c8ab8d..8d9c7bb37 100644 --- a/tests/test_optim_lr_scheduler_LinearLR.py +++ b/tests/test_optim_lr_scheduler_LinearLR.py @@ -15,21 +15,23 @@ import textwrap from apibase import APIBase -from lr_scheduler_helper import generate_torch_code +from lr_scheduler_helper import generate_lr_scheduler_test_code obj = APIBase("torch.optim.lr_scheduler.LinearLR") def test_case_1(): pytorch_code = textwrap.dedent( - generate_torch_code("torch.optim.lr_scheduler.LinearLR(sgd, verbose=True)") + generate_lr_scheduler_test_code( + "torch.optim.lr_scheduler.LinearLR(sgd, verbose=True)" + ) ) obj.run(pytorch_code, ["result1", "result2"], rtol=1.0e-5) def test_case_2(): pytorch_code = textwrap.dedent( - generate_torch_code( + generate_lr_scheduler_test_code( "torch.optim.lr_scheduler.LinearLR(sgd, start_factor=0.05, end_factor=1.0)" ) ) @@ -38,21 +40,25 @@ def test_case_2(): def test_case_3(): pytorch_code = textwrap.dedent( - generate_torch_code("torch.optim.lr_scheduler.LinearLR(sgd, total_iters=3)") + generate_lr_scheduler_test_code( + "torch.optim.lr_scheduler.LinearLR(sgd, total_iters=3)" + ) ) obj.run(pytorch_code, ["result1", "result2"], rtol=1.0e-5) def test_case_4(): pytorch_code = textwrap.dedent( - generate_torch_code("torch.optim.lr_scheduler.LinearLR(sgd, 0.05, 1)") + generate_lr_scheduler_test_code( + "torch.optim.lr_scheduler.LinearLR(sgd, 0.05, 1)" + ) ) obj.run(pytorch_code, ["result1", "result2"], rtol=1.0e-5) def test_case_5(): pytorch_code = textwrap.dedent( - generate_torch_code( + generate_lr_scheduler_test_code( "torch.optim.lr_scheduler.LinearLR(optimizer=sgd, start_factor=0.05, end_factor=1.0, total_iters=3)" ) ) @@ -61,7 +67,7 @@ def test_case_5(): def test_case_6(): pytorch_code = textwrap.dedent( - generate_torch_code( + generate_lr_scheduler_test_code( "torch.optim.lr_scheduler.LinearLR(start_factor=0.05, end_factor=1.0, total_iters=3, optimizer=sgd)" ) ) @@ -70,7 +76,7 @@ def test_case_6(): def test_case_7(): pytorch_code = textwrap.dedent( - generate_torch_code( + generate_lr_scheduler_test_code( "torch.optim.lr_scheduler.LinearLR(sgd, 0.05, 1.0, 3, -1, False)" ) ) @@ -79,7 +85,7 @@ def test_case_7(): def test_case_8(): pytorch_code = textwrap.dedent( - generate_torch_code( + generate_lr_scheduler_test_code( "torch.optim.lr_scheduler.LinearLR(optimizer=sgd, start_factor=0.05, end_factor=1.0, total_iters=3, last_epoch=-1, verbose=False)" ) ) @@ -88,7 +94,7 @@ def test_case_8(): def test_case_9(): pytorch_code = textwrap.dedent( - generate_torch_code( + generate_lr_scheduler_test_code( [ "torch.optim.lr_scheduler.LinearLR(optimizer=sgd, start_factor=0.05, end_factor=1.0, total_iters=3, last_epoch=-1, verbose=False)", "torch.optim.lr_scheduler.LinearLR(optimizer=sgd, start_factor=0.05, end_factor=1.0, total_iters=3, last_epoch=scheduler_1.last_epoch, verbose=False)", @@ -100,6 +106,6 @@ def test_case_9(): def test_case_10(): pytorch_code = textwrap.dedent( - generate_torch_code("torch.optim.lr_scheduler.LinearLR(sgd)") + generate_lr_scheduler_test_code("torch.optim.lr_scheduler.LinearLR(sgd)") ) obj.run(pytorch_code, ["result1", "result2"], rtol=1.0e-5) diff --git a/tests/test_signbit.py b/tests/test_signbit.py index 3ecc5aee4..bb87afebc 100644 --- a/tests/test_signbit.py +++ b/tests/test_signbit.py @@ -23,7 +23,7 @@ def test_case_1(): pytorch_code = textwrap.dedent( """ import torch - x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype='float32') + x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype=torch.float32) result = torch.signbit(x) """ ) @@ -34,7 +34,7 @@ def test_case_2(): pytorch_code = textwrap.dedent( """ import torch - x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype='float32') + x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype=torch.float32) result = torch.signbit(input=x) """ ) @@ -45,8 +45,8 @@ def test_case_3(): pytorch_code = textwrap.dedent( """ import torch - x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype='float32') - out = torch.tensor([]) + x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype=torch.float32) + out = torch.tensor([], dtype=torch.bool) result = torch.signbit(out=out, input=x) """ ) @@ -57,8 +57,8 @@ def test_case_4(): pytorch_code = textwrap.dedent( """ import torch - x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype='float32') - out = torch.tensor([]) + x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype=torch.float32) + out = torch.tensor([], dtype=torch.bool) result = torch.signbit(input=x, out=out) """ ) @@ -69,8 +69,8 @@ def test_case_5(): pytorch_code = textwrap.dedent( """ import torch - x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype='float32') - out = torch.tensor([]) + x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype=torch.float32) + out = torch.tensor([], dtype=torch.bool) result = torch.signbit(x, out=out) """ )