File size: 1,398 Bytes
b559e06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import os
import tempfile

import torch

from ..ssd.vgg_ssd import create_vgg_ssd


def test_create_vgg_ssd():
    """Smoke test: building the network must not raise for various class counts."""
    class_counts = (2, 10, 21, 100)
    for class_count in class_counts:
        # Only construction is exercised; the resulting network is discarded.
        create_vgg_ssd(class_count)


def test_forward():
    """Run one forward pass and check the output shapes and that outputs are non-zero."""
    num_classes = 2
    batch_size = 2

    model = create_vgg_ssd(num_classes)
    model.init()
    model.eval()

    batch = torch.randn(batch_size, 3, 300, 300)
    confidences, locations = model.forward(batch)

    # 8732 is the per-image row count this test expects for a 300x300 input.
    assert confidences.shape == (batch_size, 8732, num_classes)
    assert locations.shape == (batch_size, 8732, 4)

    # At least one entry of each output must be non-zero.
    assert confidences.nonzero().size(0) > 0
    assert locations.nonzero().size(0) > 0


def test_save_model():
    """Smoke test: saving an initialized network to an open temp file must not raise."""
    model = create_vgg_ssd(10)
    model.init()
    # The anonymous temporary file is removed automatically when the block exits.
    with tempfile.TemporaryFile() as checkpoint_file:
        model.save(checkpoint_file)


def test_save_load_model_consistency():
    """Check that a saved-and-reloaded network reproduces the original outputs.

    An initialized network is saved to disk, loaded into a freshly created
    network with the same number of classes, and both are run in eval mode on
    the same random input. Their confidences and locations must match exactly.
    """
    net = create_vgg_ssd(20)
    net.init()

    # Save inside a private temporary directory. Taking only the ``.name`` of
    # a discarded NamedTemporaryFile deletes the file immediately and reuses
    # its name, which is race-prone and leaves the written checkpoint behind.
    with tempfile.TemporaryDirectory() as tmp_dir:
        model_path = os.path.join(tmp_dir, "model.pth")
        net.save(model_path)
        net_copy = create_vgg_ssd(20)
        net_copy.load(model_path)

    net.eval()
    net_copy.eval()

    x = torch.randn(1, 3, 300, 300)
    # Gradients are not needed for a pure output comparison.
    with torch.no_grad():
        confidences1, locations1 = net.forward(x)
        confidences2, locations2 = net_copy.forward(x)

    # torch.equal is True only when sizes and all elements are identical.
    assert torch.equal(confidences1, confidences2)
    assert torch.equal(locations1, locations2)