File size: 2,793 Bytes
a69d385
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from __future__ import division
from __future__ import print_function

import argparse
import time

import torch
from spatial_correlation_sampler import SpatialCorrelationSampler
from tqdm import trange

# Multipliers that convert seconds (as measured by time.time()) into each
# reporting unit selectable with --scale.
TIME_SCALES = {'s': 1, 'ms': 1000, 'us': 1000000}

parser = argparse.ArgumentParser(
    description='Time the forward and backward passes of '
                'SpatialCorrelationSampler.')
# argparse ignores `default` on a positional unless nargs='?' makes it
# optional; without it the backend had to be given on every invocation.
parser.add_argument('backend', nargs='?', choices=['cpu', 'cuda'],
                    default='cuda',
                    help='device to benchmark on (default: %(default)s)')
parser.add_argument('-b', '--batch-size', type=int, default=16)
parser.add_argument('-k', '--kernel-size', type=int, default=3)
parser.add_argument('--patch', type=int, default=3)
parser.add_argument('--patch_dilation', type=int, default=2)
parser.add_argument('-c', '--channel', type=int, default=64)
# No short option for --height: -h is reserved for --help.
parser.add_argument('--height', type=int, default=100)
parser.add_argument('-w', '--width', type=int, default=100)
parser.add_argument('-s', '--stride', type=int, default=2)
parser.add_argument('-p', '--pad', type=int, default=1)
parser.add_argument('--scale', choices=['s', 'ms', 'us'], default='us',
                    help='unit for the reported times (default: %(default)s)')
parser.add_argument('-r', '--runs', type=int, default=100,
                    help='number of timed iterations (default: %(default)s)')
parser.add_argument('--dilation', type=int, default=2)
# The script previously had no default here and fell back to float64 when the
# option was omitted; 'double' makes that existing behaviour explicit.
parser.add_argument('-d', '--dtype', choices=['half', 'float', 'double'],
                    default='double',
                    help='tensor precision (default: %(default)s)')

args = parser.parse_args()

# Averages are computed as total_time / runs, so at least one run is needed.
if args.runs < 1:
    parser.error('--runs must be a positive integer')

device = torch.device(args.backend)

# Map the --dtype choice to a torch dtype. Any other value (including an
# unset option) falls back to float64, as the original if/elif/else did.
DTYPES = {'half': torch.float16, 'float': torch.float32, 'double': torch.float64}
dtype = DTYPES.get(args.dtype, torch.float64)

# Two random inputs of identical shape, built from --batch-size, --channel,
# --height and --width in that order.
input1 = torch.randn(args.batch_size,
                     args.channel,
                     args.height,
                     args.width,
                     dtype=dtype,
                     device=device,
                     requires_grad=True)
# randn_like copies shape/dtype/device but NOT requires_grad (it defaults to
# False), so request it explicitly; otherwise the backward benchmark would
# only ask autograd for the gradient w.r.t. input1.
input2 = torch.randn_like(input1, requires_grad=True)

# Arguments are forwarded positionally; presumably they map to the sampler's
# (kernel_size, patch_size, stride, padding, dilation, dilation_patch)
# parameters -- verify against the installed library's signature.
correlation_sampler = SpatialCorrelationSampler(
    args.kernel_size,
    args.patch,
    args.stride,
    args.pad,
    args.dilation,
    args.patch_dilation)

def _synchronize():
    """Wait for all queued work on the benchmark device to finish.

    CUDA kernels are launched asynchronously, so without this time.time()
    would capture only launch overhead, and unfinished work from one timed
    region would be charged to the next. It is a no-op on the CPU backend.
    """
    if device.type == 'cuda':
        torch.cuda.synchronize(device)


# Warm-up pass: forces lazy CUDA/extension initialization (and first-call
# allocations) so they are not charged to the first timed iteration.
output = correlation_sampler(input1, input2)
print(output.size())
output.mean().backward()
_synchronize()

forward_min = float('inf')
forward_time = 0
backward_min = float('inf')
backward_time = 0
for _ in trange(args.runs):
    correlation_sampler.zero_grad()
    # zero_grad() above only resets module parameters. The inputs are leaf
    # tensors whose .grad would otherwise keep accumulating across runs,
    # adding an extra accumulation step to every timed backward pass.
    input1.grad = None
    input2.grad = None

    start = time.time()
    output = correlation_sampler(input1, input2)
    _synchronize()
    elapsed = time.time() - start
    forward_min = min(forward_min, elapsed)
    forward_time += elapsed

    # Reduce to a scalar outside both timed regions so that only the
    # sampler's backward is measured below.
    loss = output.mean()
    _synchronize()

    start = time.time()
    loss.backward()
    _synchronize()
    elapsed = time.time() - start
    backward_min = min(backward_min, elapsed)
    backward_time += elapsed

# Report the best and the mean wall-clock time of each pass, converted from
# seconds into the unit chosen with --scale (see TIME_SCALES).
unit = args.scale
factor = TIME_SCALES[unit]

forward_min = forward_min * factor
backward_min = backward_min * factor
forward_average = forward_time / args.runs * factor
backward_average = backward_time / args.runs * factor

summary = 'Forward: {0:.3f}/{1:.3f} {4} | Backward {2:.3f}/{3:.3f} {4}'
print(summary.format(forward_min, forward_average,
                     backward_min, backward_average, unit))