forked from ClementPinard/Pytorch-Correlation-extension
-
Notifications
You must be signed in to change notification settings - Fork 0
/
grad_check.py
47 lines (40 loc) · 1.85 KB
/
grad_check.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import argparse
import torch
# torch.set_printoptions(precision=1, threshold=10000)
from torch.autograd import gradcheck
from spatial_correlation_sampler import SpatialCorrelationSampler
def build_parser():
    """Return the command-line parser for the SpatialCorrelationSampler gradcheck.

    ``backend`` is an optional positional argument: ``nargs='?'`` is what lets
    its ``default='cuda'`` apply when it is omitted (a positional without it is
    required, and argparse would silently ignore the default).
    """
    parser = argparse.ArgumentParser(
        description='Check SpatialCorrelationSampler gradients with '
                    'torch.autograd.gradcheck.')
    parser.add_argument('backend', nargs='?', choices=['cpu', 'cuda'],
                        default='cuda')
    parser.add_argument('-b', '--batch-size', type=int, default=2)
    parser.add_argument('-k', '--kernel-size', type=int, default=3)
    parser.add_argument('--patch', type=int, default=3)
    # Original underscore spelling kept for existing command lines; the
    # hyphenated alias matches the other options. dest stays 'patch_dilation'.
    parser.add_argument('--patch_dilation', '--patch-dilation', type=int,
                        default=2)
    parser.add_argument('-c', '--channel', type=int, default=2)
    # No short flag: -h is reserved for --help.
    parser.add_argument('--height', type=int, default=10)
    parser.add_argument('-w', '--width', type=int, default=10)
    parser.add_argument('-s', '--stride', type=int, default=2)
    parser.add_argument('-p', '--pad', type=int, default=1)
    parser.add_argument('-d', '--dilation', type=int, default=2)
    return parser


def main(argv=None):
    """Run gradcheck on SpatialCorrelationSampler and print 'Ok' on success.

    Args:
        argv: list of command-line arguments to parse; ``None`` means
            ``sys.argv[1:]`` (argparse's default).
    """
    parser = build_parser()
    args = parser.parse_args(argv)
    if args.backend == 'cuda' and not torch.cuda.is_available():
        parser.error('cuda backend requested but torch.cuda.is_available() '
                     'is False; pass "cpu" instead')

    device = torch.device(args.backend)
    shape = (args.batch_size, args.channel, args.height, args.width)
    # gradcheck compares analytical gradients against finite differences;
    # PyTorch recommends double-precision inputs for it.
    input1 = torch.randn(*shape, dtype=torch.float64, device=device,
                         requires_grad=True)
    input2 = torch.randn(*shape, dtype=torch.float64, device=device,
                         requires_grad=True)

    correlation_sampler = SpatialCorrelationSampler(args.kernel_size,
                                                    args.patch,
                                                    args.stride,
                                                    args.pad,
                                                    args.dilation,
                                                    args.patch_dilation)
    # With the default raise_exception=True, gradcheck raises on a mismatch,
    # so reaching the print means the check passed.
    if gradcheck(correlation_sampler, [input1, input2]):
        print('Ok')


if __name__ == '__main__':
    main()