Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
co63oc committed Dec 21, 2023
1 parent f4b789e commit 8252a5a
Show file tree
Hide file tree
Showing 12 changed files with 52 additions and 49 deletions.
2 changes: 1 addition & 1 deletion tests/test_Tensor_index_fill.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_case_3():
indices = torch.tensor([0, 1])
dim = 0
value = -1
result = torch.eye(3, 4).index_fill(dim, indices, value)
result = torch.eye(3, 4).index_fill(index=indices, dim=dim, value=value)
"""
)
obj.run(pytorch_code, ["result"])
Expand Down
2 changes: 1 addition & 1 deletion tests/test_Tensor_index_fill_.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_case_3():
indices = torch.tensor([0, 1])
dim = 0
value = -1
result = x.index_fill_(dim, indices, value)
result = x.index_fill_(index=indices, dim=dim, value=value)
"""
)
obj.run(pytorch_code, ["result", "x"])
Expand Down
2 changes: 1 addition & 1 deletion tests/test_Tensor_masked_fill.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def test_case_5():
import torch
a = torch.Tensor([[1.0,0.2], [0.3,0.4]])
b = torch.Tensor([[1,0], [1,1]])
result = a.masked_fill(value=0.1, mask=b==1)
result = a.masked_fill(b==1, 0.1)
"""
)
obj.run(pytorch_code, ["result"])
2 changes: 1 addition & 1 deletion tests/test_Tensor_masked_fill_.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def test_case_5():
import torch
a = torch.Tensor([[1.0,0.2], [0.3,0.4]])
b = torch.Tensor([[1,0], [1,1]])
result = a.masked_fill_(value=0.1, mask=b==1)
result = a.masked_fill_(b==1, 0.1)
"""
)
obj.run(pytorch_code, ["result", "a"])
4 changes: 2 additions & 2 deletions tests/test_Tensor_signbit.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype='float32')
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype=torch.float32)
result = x.signbit()
"""
)
Expand All @@ -34,7 +34,7 @@ def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype='float64')
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype=torch.float64)
result = x.signbit()
"""
)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_combinations.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_case_1():
"""
import torch
x = torch.tensor([1, 2, 3], dtype=torch.int32)
result = torch.combinations(input=x, r=2)
result = torch.combinations(input=x)
"""
)
obj.run(pytorch_code, ["result"])
Expand All @@ -36,7 +36,7 @@ def test_case_2():
"""
import torch
x = torch.tensor([1, 2, 3], dtype=torch.int32)
result = torch.combinations(x, r=2)
result = torch.combinations(x)
"""
)
obj.run(pytorch_code, ["result"])
Expand Down
4 changes: 2 additions & 2 deletions tests/test_histogramdd.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_case_1():
x = torch.tensor([[0., 1.], [1., 0.], [2.,0.], [2., 2.]])
bins = [3,3]
weights = torch.tensor([1., 2., 4., 8.])
result = torch.histogramdd(x, bins=bins, weight=weights)
result = torch.histogramdd(x, bins=bins)
"""
)
obj.run(pytorch_code, ["result"])
Expand All @@ -39,7 +39,7 @@ def test_case_2():
x = torch.tensor([[0., 1.], [1., 0.], [2.,0.], [2., 2.]])
bins = [3,3]
weights = torch.tensor([1., 2., 4., 8.])
result = torch.histogramdd(input=x, bins=bins, weight=weights)
result = torch.histogramdd(input=x, bins=bins)
"""
)
obj.run(pytorch_code, ["result"])
Expand Down
2 changes: 1 addition & 1 deletion tests/test_hypot.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_case_3():
a = torch.tensor([1., 2, 3])
b = torch.tensor([4., 5, 6])
out = torch.tensor([4., 5, 6])
result = torch.hypot(input=a, other=b, out=out)
result = torch.hypot(other=b, input=a, out=out)
"""
)
obj.run(pytorch_code, ["out"])
Expand Down
4 changes: 2 additions & 2 deletions tests/test_nn_Module_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def test_case_1():
obj.run(pytorch_code, ["result"])


# Will match torch.nn.Module; the named parameter "dst_type" cannot be resolved.
def _test_case_2():
# Will match torch.Tensor.type to resolve the "dst_type" parameter.
def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
Expand Down
31 changes: 14 additions & 17 deletions tests/test_optim_lr_scheduler_CosineAnnealingWarmRestarts.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
import textwrap

from apibase import APIBase
from lr_scheduler_helper import generate_torch_code
from lr_scheduler_helper import generate_lr_scheduler_test_code

obj = APIBase("torch.optim.lr_scheduler.CosineAnnealingWarmRestarts")


def test_case_1():
pytorch_code = textwrap.dedent(
generate_torch_code(
generate_lr_scheduler_test_code(
"torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(sgd, 10)"
)
)
Expand All @@ -31,58 +31,55 @@ def test_case_1():

def test_case_2():
pytorch_code = textwrap.dedent(
generate_torch_code(
"torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(sgd, T_0=10)"
generate_lr_scheduler_test_code(
"torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(sgd, T_0=10, T_mult=1)"
)
)
obj.run(pytorch_code, ["result1", "result2"], rtol=1.0e-5)


def test_case_3():
pytorch_code = textwrap.dedent(
generate_torch_code(
"torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer=sgd, T_0=10)"
generate_lr_scheduler_test_code(
"torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(T_mult=1, optimizer=sgd, T_0=10)"
)
)
obj.run(pytorch_code, ["result1", "result2"], rtol=1.0e-5)


def test_case_4():
pytorch_code = textwrap.dedent(
generate_torch_code(
"torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer=sgd, T_0=10, eta_min=0.0, last_epoch=-1, verbose=True)"
generate_lr_scheduler_test_code(
"torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer=sgd, T_0=10, T_mult=1, eta_min=0.0, last_epoch=-1, verbose=True)"
)
)
obj.run(pytorch_code, ["result1", "result2"], rtol=1.0e-5)


def test_case_5():
pytorch_code = textwrap.dedent(
generate_torch_code(
"torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer=sgd, T_0=10, eta_min=0.05, verbose=True)"
generate_lr_scheduler_test_code(
"torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer=sgd, T_0=10, T_mult=1, eta_min=0.05, verbose=True)"
)
)
obj.run(pytorch_code, ["result1", "result2"], rtol=1.0e-5)


def test_case_6():
pytorch_code = textwrap.dedent(
generate_torch_code(
generate_lr_scheduler_test_code(
"torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(sgd, 10, 1, 1.0, -1, False)"
)
)
obj.run(pytorch_code, ["result1", "result2"], rtol=1.0e-5)


# reference: https://www.paddlepaddle.org.cn/documentation/docs/en/api/paddle/optimizer/lr/CosineAnnealingDecay_en.html
# note: paddle does not support restart
# paddle result differs from pytorch result
def test_case_7():
pytorch_code = textwrap.dedent(
generate_torch_code(
generate_lr_scheduler_test_code(
[
"torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer=sgd, T_0=10, eta_min=0.0, last_epoch=-1, verbose=False)",
"torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer=sgd, T_0=10, eta_min=0.0, last_epoch=scheduler_1.last_epoch, verbose=False)",
"torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer=sgd, T_0=10, T_mult=1, eta_min=0.0, last_epoch=-1, verbose=False)",
"torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer=sgd, T_0=10, T_mult=2, eta_min=0.0, last_epoch=scheduler_1.last_epoch, verbose=False)",
]
)
)
Expand Down
28 changes: 17 additions & 11 deletions tests/test_optim_lr_scheduler_LinearLR.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,23 @@
import textwrap

from apibase import APIBase
from lr_scheduler_helper import generate_torch_code
from lr_scheduler_helper import generate_lr_scheduler_test_code

obj = APIBase("torch.optim.lr_scheduler.LinearLR")


def test_case_1():
pytorch_code = textwrap.dedent(
generate_torch_code("torch.optim.lr_scheduler.LinearLR(sgd, verbose=True)")
generate_lr_scheduler_test_code(
"torch.optim.lr_scheduler.LinearLR(sgd, verbose=True)"
)
)
obj.run(pytorch_code, ["result1", "result2"], rtol=1.0e-5)


def test_case_2():
pytorch_code = textwrap.dedent(
generate_torch_code(
generate_lr_scheduler_test_code(
"torch.optim.lr_scheduler.LinearLR(sgd, start_factor=0.05, end_factor=1.0)"
)
)
Expand All @@ -38,21 +40,25 @@ def test_case_2():

def test_case_3():
pytorch_code = textwrap.dedent(
generate_torch_code("torch.optim.lr_scheduler.LinearLR(sgd, total_iters=3)")
generate_lr_scheduler_test_code(
"torch.optim.lr_scheduler.LinearLR(sgd, total_iters=3)"
)
)
obj.run(pytorch_code, ["result1", "result2"], rtol=1.0e-5)


def test_case_4():
pytorch_code = textwrap.dedent(
generate_torch_code("torch.optim.lr_scheduler.LinearLR(sgd, 0.05, 1)")
generate_lr_scheduler_test_code(
"torch.optim.lr_scheduler.LinearLR(sgd, 0.05, 1)"
)
)
obj.run(pytorch_code, ["result1", "result2"], rtol=1.0e-5)


def test_case_5():
pytorch_code = textwrap.dedent(
generate_torch_code(
generate_lr_scheduler_test_code(
"torch.optim.lr_scheduler.LinearLR(optimizer=sgd, start_factor=0.05, end_factor=1.0, total_iters=3)"
)
)
Expand All @@ -61,7 +67,7 @@ def test_case_5():

def test_case_6():
pytorch_code = textwrap.dedent(
generate_torch_code(
generate_lr_scheduler_test_code(
"torch.optim.lr_scheduler.LinearLR(start_factor=0.05, end_factor=1.0, total_iters=3, optimizer=sgd)"
)
)
Expand All @@ -70,7 +76,7 @@ def test_case_6():

def test_case_7():
pytorch_code = textwrap.dedent(
generate_torch_code(
generate_lr_scheduler_test_code(
"torch.optim.lr_scheduler.LinearLR(sgd, 0.05, 1.0, 3, -1, False)"
)
)
Expand All @@ -79,7 +85,7 @@ def test_case_7():

def test_case_8():
pytorch_code = textwrap.dedent(
generate_torch_code(
generate_lr_scheduler_test_code(
"torch.optim.lr_scheduler.LinearLR(optimizer=sgd, start_factor=0.05, end_factor=1.0, total_iters=3, last_epoch=-1, verbose=False)"
)
)
Expand All @@ -88,7 +94,7 @@ def test_case_8():

def test_case_9():
pytorch_code = textwrap.dedent(
generate_torch_code(
generate_lr_scheduler_test_code(
[
"torch.optim.lr_scheduler.LinearLR(optimizer=sgd, start_factor=0.05, end_factor=1.0, total_iters=3, last_epoch=-1, verbose=False)",
"torch.optim.lr_scheduler.LinearLR(optimizer=sgd, start_factor=0.05, end_factor=1.0, total_iters=3, last_epoch=scheduler_1.last_epoch, verbose=False)",
Expand All @@ -100,6 +106,6 @@ def test_case_9():

def test_case_10():
pytorch_code = textwrap.dedent(
generate_torch_code("torch.optim.lr_scheduler.LinearLR(sgd)")
generate_lr_scheduler_test_code("torch.optim.lr_scheduler.LinearLR(sgd)")
)
obj.run(pytorch_code, ["result1", "result2"], rtol=1.0e-5)
16 changes: 8 additions & 8 deletions tests/test_signbit.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype='float32')
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype=torch.float32)
result = torch.signbit(x)
"""
)
Expand All @@ -34,7 +34,7 @@ def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype='float32')
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype=torch.float32)
result = torch.signbit(input=x)
"""
)
Expand All @@ -45,8 +45,8 @@ def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype='float32')
out = torch.tensor([])
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype=torch.float32)
out = torch.tensor([], dtype=torch.bool)
result = torch.signbit(out=out, input=x)
"""
)
Expand All @@ -57,8 +57,8 @@ def test_case_4():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype='float32')
out = torch.tensor([])
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype=torch.float32)
out = torch.tensor([], dtype=torch.bool)
result = torch.signbit(input=x, out=out)
"""
)
Expand All @@ -69,8 +69,8 @@ def test_case_5():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype='float32')
out = torch.tensor([])
x = torch.tensor([-0., 1.1, -2.1, 0., 2.5], dtype=torch.float32)
out = torch.tensor([], dtype=torch.bool)
result = torch.signbit(x, out=out)
"""
)
Expand Down

0 comments on commit 8252a5a

Please sign in to comment.