# fa0311's picture
# Upload 81 files
# b559e06 verified
from ..ssd.vgg_ssd import create_vgg_ssd
import torch
import tempfile
def test_create_vgg_ssd():
    """Smoke test: building the network succeeds for several class counts."""
    class_counts = (2, 10, 21, 100)
    for count in class_counts:
        # Only construction is exercised; the returned network is discarded.
        create_vgg_ssd(count)
def test_forward():
    """Run a forward pass on random input and check the two outputs.

    For a batch of two 3x300x300 random images the network must return
    confidences of shape (2, 8732, num_classes) and locations of shape
    (2, 8732, 4), and neither tensor may be entirely zero.
    """
    batch_size = 2
    for num_classes in (2,):
        model = create_vgg_ssd(num_classes)
        model.init()
        model.eval()

        images = torch.randn(batch_size, 3, 300, 300)
        confidences, locations = model.forward(images)

        # Shape checks.
        expected_conf_shape = (batch_size, 8732, num_classes)
        expected_loc_shape = (batch_size, 8732, 4)
        assert tuple(confidences.shape) == expected_conf_shape
        assert tuple(locations.shape) == expected_loc_shape

        # Each output must contain at least one non-zero entry.
        assert len(confidences.nonzero()) > 0
        assert len(locations.nonzero()) > 0
def test_save_model():
    """Saving a 10-class network to an open temporary file must not raise."""
    model = create_vgg_ssd(10)
    model.init()
    # The anonymous temporary file is closed and removed when the block exits.
    with tempfile.TemporaryFile() as handle:
        model.save(handle)
def test_save_load_model_consistency():
    """A network restored from a saved checkpoint must reproduce the outputs.

    A 20-class network is saved to disk, loaded into a second, freshly
    created 20-class network, and both are run in eval mode on the same
    random input. Confidences and locations must match element for element.
    """
    import os  # local import: only needed to build the checkpoint path

    net = create_vgg_ssd(20)
    net.init()
    net_copy = create_vgg_ssd(20)

    # Write the checkpoint inside a managed temporary directory.  The previous
    # `tempfile.NamedTemporaryFile().name` pattern dropped the file object
    # immediately, leaving only a bare, race-prone path, and the checkpoint
    # written there was never cleaned up.
    with tempfile.TemporaryDirectory() as tmp_dir:
        model_path = os.path.join(tmp_dir, "model.pth")
        net.save(model_path)
        net_copy.load(model_path)

    net.eval()
    net_copy.eval()

    x = torch.randn(1, 3, 300, 300)
    confidences1, locations1 = net.forward(x)
    confidences2, locations2 = net_copy.forward(x)

    # Exact equality is intended: both networks hold identical weights.
    assert torch.equal(confidences1, confidences2)
    assert torch.equal(locations1, locations2)