"""Smoke tests for the VGG-based SSD network built by ``create_vgg_ssd``."""
import os
import tempfile

import torch

from ..ssd.vgg_ssd import create_vgg_ssd

# Size of the second output dimension these tests expect for a 300x300 input.
# NOTE(review): presumably the number of prior boxes of the SSD300
# configuration -- confirm against the model config.
EXPECTED_NUM_PRIORS = 8732


def test_create_vgg_ssd():
    """Check that the network can be constructed for several class counts."""
    for num_classes in (2, 10, 21, 100):
        create_vgg_ssd(num_classes)


def test_forward():
    """Check output shapes of a forward pass and that outputs are not all zero."""
    batch_size = 2
    for num_classes in (2,):
        net = create_vgg_ssd(num_classes)
        net.init()
        net.eval()
        x = torch.randn(batch_size, 3, 300, 300)
        # Only the outputs are inspected, so skip autograd bookkeeping.
        with torch.no_grad():
            confidences, locations = net.forward(x)
        assert confidences.size() == torch.Size(
            [batch_size, EXPECTED_NUM_PRIORS, num_classes]
        )
        assert locations.size() == torch.Size(
            [batch_size, EXPECTED_NUM_PRIORS, 4]
        )
        # An all-zero output would indicate the network produced nothing useful.
        assert confidences.nonzero().size(0) != 0
        assert locations.nonzero().size(0) != 0


def test_save_model():
    """Check that ``net.save`` accepts an open temporary file object."""
    net = create_vgg_ssd(10)
    net.init()
    with tempfile.TemporaryFile() as f:
        net.save(f)


def test_save_load_model_consistency():
    """Check that a saved-then-loaded network reproduces the original outputs."""
    num_classes = 20
    net = create_vgg_ssd(num_classes)
    net.init()
    net_copy = create_vgg_ssd(num_classes)

    # A private temporary directory gives a path that is reserved for this
    # test and is removed afterwards. The previous
    # ``NamedTemporaryFile().name`` discarded the file object at once, so the
    # name was no longer reserved and the saved checkpoint was never cleaned up.
    with tempfile.TemporaryDirectory() as tmp_dir:
        model_path = os.path.join(tmp_dir, "model.pth")
        net.save(model_path)
        net_copy.load(model_path)

    net.eval()
    net_copy.eval()

    x = torch.randn(1, 3, 300, 300)
    with torch.no_grad():
        confidences1, locations1 = net.forward(x)
        confidences2, locations2 = net_copy.forward(x)

    # Outputs are required to match exactly, element for element.
    assert torch.equal(confidences1, confidences2)
    assert torch.equal(locations1, locations2)