# File size: 1,895 Bytes
# Revision: e5f748f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import argparse
import torch
# torch.set_printoptions(precision=1, threshold=10000)
from torch.autograd import gradcheck
from spatial_correlation_sampler import SpatialCorrelationSampler

# Command-line options for this gradient-check script. The values parsed here
# size the random inputs and configure the SpatialCorrelationSampler below.
parser = argparse.ArgumentParser()
# nargs='?' makes the positional optional; without it argparse always requires
# the argument and the declared default of 'cuda' could never take effect.
parser.add_argument('backend', nargs='?', choices=['cpu', 'cuda'],
                    default='cuda',
                    help='device the random inputs are created on')
parser.add_argument('-b', '--batch-size', type=int, default=2)
parser.add_argument('-k', '--kernel-size', type=int, default=3)
parser.add_argument('--patch', type=int, default=3)
# Hyphenated spelling matches the other multi-word flags; the original
# underscore spelling stays accepted. argparse derives the destination from
# the first long option, so the value is still read as args.patch_dilation.
parser.add_argument('--patch-dilation', '--patch_dilation', type=int,
                    default=2)
parser.add_argument('-c', '--channel', type=int, default=2)
parser.add_argument('--height', type=int, default=10)
parser.add_argument('-w', '--width', type=int, default=10)
parser.add_argument('-s', '--stride', type=int, default=2)
parser.add_argument('-p', '--pad', type=int, default=1)
parser.add_argument('-d', '--dilation', type=int, default=2)

args = parser.parse_args()

# Both operands share the same (batch, channel, height, width) shape taken
# from the command line and live on the requested backend.
input_shape = (args.batch_size, args.channel, args.height, args.width)
target_device = torch.device(args.backend)

# Double precision inputs that track gradients, so they can be handed to
# gradcheck directly. input1 is drawn before input2, as in the original order.
input1 = torch.randn(input_shape,
                     dtype=torch.float64,
                     device=target_device,
                     requires_grad=True)
input2 = torch.randn(input_shape,
                     dtype=torch.float64,
                     device=target_device,
                     requires_grad=True)

# Constructor arguments are passed positionally, in this exact order.
sampler_args = (
    args.kernel_size,
    args.patch,
    args.stride,
    args.pad,
    args.dilation,
    args.patch_dilation,
)
correlation_sampler = SpatialCorrelationSampler(*sampler_args)

# Run torch's gradient check on the sampler with both inputs and report
# success; nothing is printed when the check does not return a truthy value.
gradients_ok = gradcheck(correlation_sampler, [input1, input2])
if gradients_ok:
    print('Ok')