Upload 847 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- annotator/__pycache__/util.cpython-310.pyc +0 -0
- annotator/__pycache__/util.cpython-38.pyc +0 -0
- annotator/canny/__init__.py +6 -0
- annotator/canny/__pycache__/__init__.cpython-38.pyc +0 -0
- annotator/ckpts/ckpts.txt +1 -0
- annotator/hed/__init__.py +80 -0
- annotator/hed/__pycache__/__init__.cpython-38.pyc +0 -0
- annotator/lineart/LICENSE +21 -0
- annotator/lineart/__init__.py +124 -0
- annotator/lineart/__pycache__/__init__.cpython-38.pyc +0 -0
- annotator/lineart_anime/LICENSE +21 -0
- annotator/lineart_anime/__init__.py +150 -0
- annotator/lineart_anime/__pycache__/__init__.cpython-38.pyc +0 -0
- annotator/midas/LICENSE +21 -0
- annotator/midas/__init__.py +31 -0
- annotator/midas/api.py +169 -0
- annotator/midas/midas/__init__.py +0 -0
- annotator/midas/midas/base_model.py +16 -0
- annotator/midas/midas/blocks.py +342 -0
- annotator/midas/midas/dpt_depth.py +109 -0
- annotator/midas/midas/midas_net.py +76 -0
- annotator/midas/midas/midas_net_custom.py +128 -0
- annotator/midas/midas/transforms.py +234 -0
- annotator/midas/midas/vit.py +491 -0
- annotator/midas/utils.py +189 -0
- annotator/mlsd/LICENSE +201 -0
- annotator/mlsd/__init__.py +43 -0
- annotator/mlsd/__pycache__/__init__.cpython-38.pyc +0 -0
- annotator/mlsd/models/mbv2_mlsd_large.py +292 -0
- annotator/mlsd/models/mbv2_mlsd_tiny.py +275 -0
- annotator/mlsd/utils.py +580 -0
- annotator/normalbae/LICENSE +21 -0
- annotator/normalbae/__init__.py +55 -0
- annotator/normalbae/models/NNET.py +22 -0
- annotator/normalbae/models/baseline.py +85 -0
- annotator/normalbae/models/submodules/decoder.py +202 -0
- annotator/normalbae/models/submodules/efficientnet_repo/BENCHMARK.md +555 -0
- annotator/normalbae/models/submodules/efficientnet_repo/LICENSE +201 -0
- annotator/normalbae/models/submodules/efficientnet_repo/README.md +323 -0
- annotator/normalbae/models/submodules/efficientnet_repo/caffe2_benchmark.py +65 -0
- annotator/normalbae/models/submodules/efficientnet_repo/caffe2_validate.py +138 -0
- annotator/normalbae/models/submodules/efficientnet_repo/data/__init__.py +3 -0
- annotator/normalbae/models/submodules/efficientnet_repo/data/dataset.py +91 -0
- annotator/normalbae/models/submodules/efficientnet_repo/data/loader.py +108 -0
- annotator/normalbae/models/submodules/efficientnet_repo/data/tf_preprocessing.py +234 -0
- annotator/normalbae/models/submodules/efficientnet_repo/data/transforms.py +150 -0
- annotator/normalbae/models/submodules/efficientnet_repo/geffnet/__init__.py +5 -0
- annotator/normalbae/models/submodules/efficientnet_repo/geffnet/activations/__init__.py +137 -0
- annotator/normalbae/models/submodules/efficientnet_repo/geffnet/activations/activations.py +102 -0
- annotator/normalbae/models/submodules/efficientnet_repo/geffnet/activations/activations_jit.py +79 -0
annotator/__pycache__/util.cpython-310.pyc
ADDED
Binary file (3.09 kB). View file
|
|
annotator/__pycache__/util.cpython-38.pyc
ADDED
Binary file (3.06 kB). View file
|
|
annotator/canny/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
|
3 |
+
|
4 |
+
class CannyDetector:
|
5 |
+
def __call__(self, img, low_threshold, high_threshold):
|
6 |
+
return cv2.Canny(img, low_threshold, high_threshold)
|
annotator/canny/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (517 Bytes). View file
|
|
annotator/ckpts/ckpts.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Weights here.
|
annotator/hed/__init__.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This is an improved version and model of HED edge detection with Apache License, Version 2.0.
|
2 |
+
# Please use this implementation in your products
|
3 |
+
# This implementation may produce slightly different results from Saining Xie's official implementations,
|
4 |
+
# but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations.
|
5 |
+
# Different from official models and other implementations, this is an RGB-input model (rather than BGR)
|
6 |
+
# and in this way it works better for gradio's RGB protocol
|
7 |
+
|
8 |
+
import os
|
9 |
+
import cv2
|
10 |
+
import torch
|
11 |
+
import numpy as np
|
12 |
+
|
13 |
+
from einops import rearrange
|
14 |
+
from annotator.util import annotator_ckpts_path, safe_step
|
15 |
+
|
16 |
+
|
17 |
+
class DoubleConvBlock(torch.nn.Module):
|
18 |
+
def __init__(self, input_channel, output_channel, layer_number):
|
19 |
+
super().__init__()
|
20 |
+
self.convs = torch.nn.Sequential()
|
21 |
+
self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
|
22 |
+
for i in range(1, layer_number):
|
23 |
+
self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
|
24 |
+
self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0)
|
25 |
+
|
26 |
+
def __call__(self, x, down_sampling=False):
|
27 |
+
h = x
|
28 |
+
if down_sampling:
|
29 |
+
h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2))
|
30 |
+
for conv in self.convs:
|
31 |
+
h = conv(h)
|
32 |
+
h = torch.nn.functional.relu(h)
|
33 |
+
return h, self.projection(h)
|
34 |
+
|
35 |
+
|
36 |
+
class ControlNetHED_Apache2(torch.nn.Module):
|
37 |
+
def __init__(self):
|
38 |
+
super().__init__()
|
39 |
+
self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1)))
|
40 |
+
self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2)
|
41 |
+
self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2)
|
42 |
+
self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3)
|
43 |
+
self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3)
|
44 |
+
self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3)
|
45 |
+
|
46 |
+
def __call__(self, x):
|
47 |
+
h = x - self.norm
|
48 |
+
h, projection1 = self.block1(h)
|
49 |
+
h, projection2 = self.block2(h, down_sampling=True)
|
50 |
+
h, projection3 = self.block3(h, down_sampling=True)
|
51 |
+
h, projection4 = self.block4(h, down_sampling=True)
|
52 |
+
h, projection5 = self.block5(h, down_sampling=True)
|
53 |
+
return projection1, projection2, projection3, projection4, projection5
|
54 |
+
|
55 |
+
|
56 |
+
class HEDdetector:
|
57 |
+
def __init__(self):
|
58 |
+
remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/ControlNetHED.pth"
|
59 |
+
modelpath = os.path.join(annotator_ckpts_path, "ControlNetHED.pth")
|
60 |
+
if not os.path.exists(modelpath):
|
61 |
+
from basicsr.utils.download_util import load_file_from_url
|
62 |
+
load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
|
63 |
+
self.netNetwork = ControlNetHED_Apache2().float().cuda().eval()
|
64 |
+
self.netNetwork.load_state_dict(torch.load(modelpath))
|
65 |
+
|
66 |
+
def __call__(self, input_image, safe=False):
|
67 |
+
assert input_image.ndim == 3
|
68 |
+
H, W, C = input_image.shape
|
69 |
+
with torch.no_grad():
|
70 |
+
image_hed = torch.from_numpy(input_image.copy()).float().cuda()
|
71 |
+
image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
|
72 |
+
edges = self.netNetwork(image_hed)
|
73 |
+
edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges]
|
74 |
+
edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges]
|
75 |
+
edges = np.stack(edges, axis=2)
|
76 |
+
edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))
|
77 |
+
if safe:
|
78 |
+
edge = safe_step(edge)
|
79 |
+
edge = (edge * 255.0).clip(0, 255).astype(np.uint8)
|
80 |
+
return edge
|
annotator/hed/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (3.88 kB). View file
|
|
annotator/lineart/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2022 Caroline Chan
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
annotator/lineart/__init__.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# From https://github.com/carolineec/informative-drawings
|
2 |
+
# MIT License
|
3 |
+
|
4 |
+
import os
|
5 |
+
import cv2
|
6 |
+
import torch
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
import torch.nn as nn
|
10 |
+
from einops import rearrange
|
11 |
+
from annotator.util import annotator_ckpts_path
|
12 |
+
|
13 |
+
|
14 |
+
norm_layer = nn.InstanceNorm2d
|
15 |
+
|
16 |
+
|
17 |
+
class ResidualBlock(nn.Module):
|
18 |
+
def __init__(self, in_features):
|
19 |
+
super(ResidualBlock, self).__init__()
|
20 |
+
|
21 |
+
conv_block = [ nn.ReflectionPad2d(1),
|
22 |
+
nn.Conv2d(in_features, in_features, 3),
|
23 |
+
norm_layer(in_features),
|
24 |
+
nn.ReLU(inplace=True),
|
25 |
+
nn.ReflectionPad2d(1),
|
26 |
+
nn.Conv2d(in_features, in_features, 3),
|
27 |
+
norm_layer(in_features)
|
28 |
+
]
|
29 |
+
|
30 |
+
self.conv_block = nn.Sequential(*conv_block)
|
31 |
+
|
32 |
+
def forward(self, x):
|
33 |
+
return x + self.conv_block(x)
|
34 |
+
|
35 |
+
|
36 |
+
class Generator(nn.Module):
|
37 |
+
def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
|
38 |
+
super(Generator, self).__init__()
|
39 |
+
|
40 |
+
# Initial convolution block
|
41 |
+
model0 = [ nn.ReflectionPad2d(3),
|
42 |
+
nn.Conv2d(input_nc, 64, 7),
|
43 |
+
norm_layer(64),
|
44 |
+
nn.ReLU(inplace=True) ]
|
45 |
+
self.model0 = nn.Sequential(*model0)
|
46 |
+
|
47 |
+
# Downsampling
|
48 |
+
model1 = []
|
49 |
+
in_features = 64
|
50 |
+
out_features = in_features*2
|
51 |
+
for _ in range(2):
|
52 |
+
model1 += [ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
|
53 |
+
norm_layer(out_features),
|
54 |
+
nn.ReLU(inplace=True) ]
|
55 |
+
in_features = out_features
|
56 |
+
out_features = in_features*2
|
57 |
+
self.model1 = nn.Sequential(*model1)
|
58 |
+
|
59 |
+
model2 = []
|
60 |
+
# Residual blocks
|
61 |
+
for _ in range(n_residual_blocks):
|
62 |
+
model2 += [ResidualBlock(in_features)]
|
63 |
+
self.model2 = nn.Sequential(*model2)
|
64 |
+
|
65 |
+
# Upsampling
|
66 |
+
model3 = []
|
67 |
+
out_features = in_features//2
|
68 |
+
for _ in range(2):
|
69 |
+
model3 += [ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
|
70 |
+
norm_layer(out_features),
|
71 |
+
nn.ReLU(inplace=True) ]
|
72 |
+
in_features = out_features
|
73 |
+
out_features = in_features//2
|
74 |
+
self.model3 = nn.Sequential(*model3)
|
75 |
+
|
76 |
+
# Output layer
|
77 |
+
model4 = [ nn.ReflectionPad2d(3),
|
78 |
+
nn.Conv2d(64, output_nc, 7)]
|
79 |
+
if sigmoid:
|
80 |
+
model4 += [nn.Sigmoid()]
|
81 |
+
|
82 |
+
self.model4 = nn.Sequential(*model4)
|
83 |
+
|
84 |
+
def forward(self, x, cond=None):
|
85 |
+
out = self.model0(x)
|
86 |
+
out = self.model1(out)
|
87 |
+
out = self.model2(out)
|
88 |
+
out = self.model3(out)
|
89 |
+
out = self.model4(out)
|
90 |
+
|
91 |
+
return out
|
92 |
+
|
93 |
+
|
94 |
+
class LineartDetector:
|
95 |
+
def __init__(self):
|
96 |
+
self.model = self.load_model('sk_model.pth')
|
97 |
+
self.model_coarse = self.load_model('sk_model2.pth')
|
98 |
+
|
99 |
+
def load_model(self, name):
|
100 |
+
remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/" + name
|
101 |
+
modelpath = os.path.join(annotator_ckpts_path, name)
|
102 |
+
if not os.path.exists(modelpath):
|
103 |
+
from basicsr.utils.download_util import load_file_from_url
|
104 |
+
load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
|
105 |
+
model = Generator(3, 1, 3)
|
106 |
+
model.load_state_dict(torch.load(modelpath, map_location=torch.device('cpu')))
|
107 |
+
model.eval()
|
108 |
+
model = model.cuda()
|
109 |
+
return model
|
110 |
+
|
111 |
+
def __call__(self, input_image, coarse):
|
112 |
+
model = self.model_coarse if coarse else self.model
|
113 |
+
assert input_image.ndim == 3
|
114 |
+
image = input_image
|
115 |
+
with torch.no_grad():
|
116 |
+
image = torch.from_numpy(image).float().cuda()
|
117 |
+
image = image / 255.0
|
118 |
+
image = rearrange(image, 'h w c -> 1 c h w')
|
119 |
+
line = model(image)[0][0]
|
120 |
+
|
121 |
+
line = line.cpu().numpy()
|
122 |
+
line = (line * 255.0).clip(0, 255).astype(np.uint8)
|
123 |
+
|
124 |
+
return line
|
annotator/lineart/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (3.66 kB). View file
|
|
annotator/lineart_anime/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2022 Caroline Chan
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
annotator/lineart_anime/__init__.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Anime2sketch
|
2 |
+
# https://github.com/Mukosame/Anime2Sketch
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import functools
|
8 |
+
|
9 |
+
import os
|
10 |
+
import cv2
|
11 |
+
from einops import rearrange
|
12 |
+
from annotator.util import annotator_ckpts_path
|
13 |
+
|
14 |
+
|
15 |
+
class UnetGenerator(nn.Module):
|
16 |
+
"""Create a Unet-based generator"""
|
17 |
+
|
18 |
+
def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
|
19 |
+
"""Construct a Unet generator
|
20 |
+
Parameters:
|
21 |
+
input_nc (int) -- the number of channels in input images
|
22 |
+
output_nc (int) -- the number of channels in output images
|
23 |
+
num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
|
24 |
+
image of size 128x128 will become of size 1x1 # at the bottleneck
|
25 |
+
ngf (int) -- the number of filters in the last conv layer
|
26 |
+
norm_layer -- normalization layer
|
27 |
+
We construct the U-Net from the innermost layer to the outermost layer.
|
28 |
+
It is a recursive process.
|
29 |
+
"""
|
30 |
+
super(UnetGenerator, self).__init__()
|
31 |
+
# construct unet structure
|
32 |
+
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer
|
33 |
+
for _ in range(num_downs - 5): # add intermediate layers with ngf * 8 filters
|
34 |
+
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
|
35 |
+
# gradually reduce the number of filters from ngf * 8 to ngf
|
36 |
+
unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
|
37 |
+
unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
|
38 |
+
unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
|
39 |
+
self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer
|
40 |
+
|
41 |
+
def forward(self, input):
|
42 |
+
"""Standard forward"""
|
43 |
+
return self.model(input)
|
44 |
+
|
45 |
+
|
46 |
+
class UnetSkipConnectionBlock(nn.Module):
|
47 |
+
"""Defines the Unet submodule with skip connection.
|
48 |
+
X -------------------identity----------------------
|
49 |
+
|-- downsampling -- |submodule| -- upsampling --|
|
50 |
+
"""
|
51 |
+
|
52 |
+
def __init__(self, outer_nc, inner_nc, input_nc=None,
|
53 |
+
submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
|
54 |
+
"""Construct a Unet submodule with skip connections.
|
55 |
+
Parameters:
|
56 |
+
outer_nc (int) -- the number of filters in the outer conv layer
|
57 |
+
inner_nc (int) -- the number of filters in the inner conv layer
|
58 |
+
input_nc (int) -- the number of channels in input images/features
|
59 |
+
submodule (UnetSkipConnectionBlock) -- previously defined submodules
|
60 |
+
outermost (bool) -- if this module is the outermost module
|
61 |
+
innermost (bool) -- if this module is the innermost module
|
62 |
+
norm_layer -- normalization layer
|
63 |
+
use_dropout (bool) -- if use dropout layers.
|
64 |
+
"""
|
65 |
+
super(UnetSkipConnectionBlock, self).__init__()
|
66 |
+
self.outermost = outermost
|
67 |
+
if type(norm_layer) == functools.partial:
|
68 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
69 |
+
else:
|
70 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
71 |
+
if input_nc is None:
|
72 |
+
input_nc = outer_nc
|
73 |
+
downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
|
74 |
+
stride=2, padding=1, bias=use_bias)
|
75 |
+
downrelu = nn.LeakyReLU(0.2, True)
|
76 |
+
downnorm = norm_layer(inner_nc)
|
77 |
+
uprelu = nn.ReLU(True)
|
78 |
+
upnorm = norm_layer(outer_nc)
|
79 |
+
|
80 |
+
if outermost:
|
81 |
+
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
|
82 |
+
kernel_size=4, stride=2,
|
83 |
+
padding=1)
|
84 |
+
down = [downconv]
|
85 |
+
up = [uprelu, upconv, nn.Tanh()]
|
86 |
+
model = down + [submodule] + up
|
87 |
+
elif innermost:
|
88 |
+
upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
|
89 |
+
kernel_size=4, stride=2,
|
90 |
+
padding=1, bias=use_bias)
|
91 |
+
down = [downrelu, downconv]
|
92 |
+
up = [uprelu, upconv, upnorm]
|
93 |
+
model = down + up
|
94 |
+
else:
|
95 |
+
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
|
96 |
+
kernel_size=4, stride=2,
|
97 |
+
padding=1, bias=use_bias)
|
98 |
+
down = [downrelu, downconv, downnorm]
|
99 |
+
up = [uprelu, upconv, upnorm]
|
100 |
+
|
101 |
+
if use_dropout:
|
102 |
+
model = down + [submodule] + up + [nn.Dropout(0.5)]
|
103 |
+
else:
|
104 |
+
model = down + [submodule] + up
|
105 |
+
|
106 |
+
self.model = nn.Sequential(*model)
|
107 |
+
|
108 |
+
def forward(self, x):
|
109 |
+
if self.outermost:
|
110 |
+
return self.model(x)
|
111 |
+
else: # add skip connections
|
112 |
+
return torch.cat([x, self.model(x)], 1)
|
113 |
+
|
114 |
+
|
115 |
+
class LineartAnimeDetector:
|
116 |
+
def __init__(self):
|
117 |
+
remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/netG.pth"
|
118 |
+
modelpath = os.path.join(annotator_ckpts_path, "netG.pth")
|
119 |
+
if not os.path.exists(modelpath):
|
120 |
+
from basicsr.utils.download_util import load_file_from_url
|
121 |
+
load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
|
122 |
+
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
|
123 |
+
net = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False)
|
124 |
+
ckpt = torch.load(modelpath)
|
125 |
+
for key in list(ckpt.keys()):
|
126 |
+
if 'module.' in key:
|
127 |
+
ckpt[key.replace('module.', '')] = ckpt[key]
|
128 |
+
del ckpt[key]
|
129 |
+
net.load_state_dict(ckpt)
|
130 |
+
net = net.cuda()
|
131 |
+
net.eval()
|
132 |
+
self.model = net
|
133 |
+
|
134 |
+
def __call__(self, input_image):
|
135 |
+
H, W, C = input_image.shape
|
136 |
+
Hn = 256 * int(np.ceil(float(H) / 256.0))
|
137 |
+
Wn = 256 * int(np.ceil(float(W) / 256.0))
|
138 |
+
img = cv2.resize(input_image, (Wn, Hn), interpolation=cv2.INTER_CUBIC)
|
139 |
+
with torch.no_grad():
|
140 |
+
image_feed = torch.from_numpy(img).float().cuda()
|
141 |
+
image_feed = image_feed / 127.5 - 1.0
|
142 |
+
image_feed = rearrange(image_feed, 'h w c -> 1 c h w')
|
143 |
+
|
144 |
+
line = self.model(image_feed)[0, 0] * 127.5 + 127.5
|
145 |
+
line = line.cpu().numpy()
|
146 |
+
|
147 |
+
line = cv2.resize(line, (W, H), interpolation=cv2.INTER_CUBIC)
|
148 |
+
line = line.clip(0, 255).astype(np.uint8)
|
149 |
+
return line
|
150 |
+
|
annotator/lineart_anime/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (5.69 kB). View file
|
|
annotator/midas/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab)
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
annotator/midas/__init__.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Midas Depth Estimation
|
2 |
+
# From https://github.com/isl-org/MiDaS
|
3 |
+
# MIT LICENSE
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from einops import rearrange
|
10 |
+
from .api import MiDaSInference
|
11 |
+
|
12 |
+
|
13 |
+
class MidasDetector:
|
14 |
+
def __init__(self):
|
15 |
+
self.model = MiDaSInference(model_type="dpt_hybrid").cuda()
|
16 |
+
|
17 |
+
def __call__(self, input_image):
|
18 |
+
assert input_image.ndim == 3
|
19 |
+
image_depth = input_image
|
20 |
+
with torch.no_grad():
|
21 |
+
image_depth = torch.from_numpy(image_depth).float().cuda()
|
22 |
+
image_depth = image_depth / 127.5 - 1.0
|
23 |
+
image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
|
24 |
+
depth = self.model(image_depth)[0]
|
25 |
+
|
26 |
+
depth -= torch.min(depth)
|
27 |
+
depth /= torch.max(depth)
|
28 |
+
depth = depth.cpu().numpy()
|
29 |
+
depth_image = (depth * 255.0).clip(0, 255).astype(np.uint8)
|
30 |
+
|
31 |
+
return depth_image
|
annotator/midas/api.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# based on https://github.com/isl-org/MiDaS
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import os
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from torchvision.transforms import Compose
|
8 |
+
|
9 |
+
from .midas.dpt_depth import DPTDepthModel
|
10 |
+
from .midas.midas_net import MidasNet
|
11 |
+
from .midas.midas_net_custom import MidasNet_small
|
12 |
+
from .midas.transforms import Resize, NormalizeImage, PrepareForNet
|
13 |
+
from annotator.util import annotator_ckpts_path
|
14 |
+
|
15 |
+
|
16 |
+
ISL_PATHS = {
|
17 |
+
"dpt_large": os.path.join(annotator_ckpts_path, "dpt_large-midas-2f21e586.pt"),
|
18 |
+
"dpt_hybrid": os.path.join(annotator_ckpts_path, "dpt_hybrid-midas-501f0c75.pt"),
|
19 |
+
"midas_v21": "",
|
20 |
+
"midas_v21_small": "",
|
21 |
+
}
|
22 |
+
|
23 |
+
remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/dpt_hybrid-midas-501f0c75.pt"
|
24 |
+
|
25 |
+
|
26 |
+
def disabled_train(self, mode=True):
|
27 |
+
"""Overwrite model.train with this function to make sure train/eval mode
|
28 |
+
does not change anymore."""
|
29 |
+
return self
|
30 |
+
|
31 |
+
|
32 |
+
def load_midas_transform(model_type):
|
33 |
+
# https://github.com/isl-org/MiDaS/blob/master/run.py
|
34 |
+
# load transform only
|
35 |
+
if model_type == "dpt_large": # DPT-Large
|
36 |
+
net_w, net_h = 384, 384
|
37 |
+
resize_mode = "minimal"
|
38 |
+
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
39 |
+
|
40 |
+
elif model_type == "dpt_hybrid": # DPT-Hybrid
|
41 |
+
net_w, net_h = 384, 384
|
42 |
+
resize_mode = "minimal"
|
43 |
+
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
44 |
+
|
45 |
+
elif model_type == "midas_v21":
|
46 |
+
net_w, net_h = 384, 384
|
47 |
+
resize_mode = "upper_bound"
|
48 |
+
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
49 |
+
|
50 |
+
elif model_type == "midas_v21_small":
|
51 |
+
net_w, net_h = 256, 256
|
52 |
+
resize_mode = "upper_bound"
|
53 |
+
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
54 |
+
|
55 |
+
else:
|
56 |
+
assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
|
57 |
+
|
58 |
+
transform = Compose(
|
59 |
+
[
|
60 |
+
Resize(
|
61 |
+
net_w,
|
62 |
+
net_h,
|
63 |
+
resize_target=None,
|
64 |
+
keep_aspect_ratio=True,
|
65 |
+
ensure_multiple_of=32,
|
66 |
+
resize_method=resize_mode,
|
67 |
+
image_interpolation_method=cv2.INTER_CUBIC,
|
68 |
+
),
|
69 |
+
normalization,
|
70 |
+
PrepareForNet(),
|
71 |
+
]
|
72 |
+
)
|
73 |
+
|
74 |
+
return transform
|
75 |
+
|
76 |
+
|
77 |
+
def load_model(model_type):
|
78 |
+
# https://github.com/isl-org/MiDaS/blob/master/run.py
|
79 |
+
# load network
|
80 |
+
model_path = ISL_PATHS[model_type]
|
81 |
+
if model_type == "dpt_large": # DPT-Large
|
82 |
+
model = DPTDepthModel(
|
83 |
+
path=model_path,
|
84 |
+
backbone="vitl16_384",
|
85 |
+
non_negative=True,
|
86 |
+
)
|
87 |
+
net_w, net_h = 384, 384
|
88 |
+
resize_mode = "minimal"
|
89 |
+
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
90 |
+
|
91 |
+
elif model_type == "dpt_hybrid": # DPT-Hybrid
|
92 |
+
if not os.path.exists(model_path):
|
93 |
+
from basicsr.utils.download_util import load_file_from_url
|
94 |
+
load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
|
95 |
+
|
96 |
+
model = DPTDepthModel(
|
97 |
+
path=model_path,
|
98 |
+
backbone="vitb_rn50_384",
|
99 |
+
non_negative=True,
|
100 |
+
)
|
101 |
+
net_w, net_h = 384, 384
|
102 |
+
resize_mode = "minimal"
|
103 |
+
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
104 |
+
|
105 |
+
elif model_type == "midas_v21":
|
106 |
+
model = MidasNet(model_path, non_negative=True)
|
107 |
+
net_w, net_h = 384, 384
|
108 |
+
resize_mode = "upper_bound"
|
109 |
+
normalization = NormalizeImage(
|
110 |
+
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
111 |
+
)
|
112 |
+
|
113 |
+
elif model_type == "midas_v21_small":
|
114 |
+
model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
|
115 |
+
non_negative=True, blocks={'expand': True})
|
116 |
+
net_w, net_h = 256, 256
|
117 |
+
resize_mode = "upper_bound"
|
118 |
+
normalization = NormalizeImage(
|
119 |
+
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
120 |
+
)
|
121 |
+
|
122 |
+
else:
|
123 |
+
print(f"model_type '{model_type}' not implemented, use: --model_type large")
|
124 |
+
assert False
|
125 |
+
|
126 |
+
transform = Compose(
|
127 |
+
[
|
128 |
+
Resize(
|
129 |
+
net_w,
|
130 |
+
net_h,
|
131 |
+
resize_target=None,
|
132 |
+
keep_aspect_ratio=True,
|
133 |
+
ensure_multiple_of=32,
|
134 |
+
resize_method=resize_mode,
|
135 |
+
image_interpolation_method=cv2.INTER_CUBIC,
|
136 |
+
),
|
137 |
+
normalization,
|
138 |
+
PrepareForNet(),
|
139 |
+
]
|
140 |
+
)
|
141 |
+
|
142 |
+
return model.eval(), transform
|
143 |
+
|
144 |
+
|
145 |
+
class MiDaSInference(nn.Module):
|
146 |
+
MODEL_TYPES_TORCH_HUB = [
|
147 |
+
"DPT_Large",
|
148 |
+
"DPT_Hybrid",
|
149 |
+
"MiDaS_small"
|
150 |
+
]
|
151 |
+
MODEL_TYPES_ISL = [
|
152 |
+
"dpt_large",
|
153 |
+
"dpt_hybrid",
|
154 |
+
"midas_v21",
|
155 |
+
"midas_v21_small",
|
156 |
+
]
|
157 |
+
|
158 |
+
def __init__(self, model_type):
|
159 |
+
super().__init__()
|
160 |
+
assert (model_type in self.MODEL_TYPES_ISL)
|
161 |
+
model, _ = load_model(model_type)
|
162 |
+
self.model = model
|
163 |
+
self.model.train = disabled_train
|
164 |
+
|
165 |
+
def forward(self, x):
|
166 |
+
with torch.no_grad():
|
167 |
+
prediction = self.model(x)
|
168 |
+
return prediction
|
169 |
+
|
annotator/midas/midas/__init__.py
ADDED
File without changes
|
annotator/midas/midas/base_model.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
class BaseModel(torch.nn.Module):
|
5 |
+
def load(self, path):
|
6 |
+
"""Load model from file.
|
7 |
+
|
8 |
+
Args:
|
9 |
+
path (str): file path
|
10 |
+
"""
|
11 |
+
parameters = torch.load(path, map_location=torch.device('cpu'))
|
12 |
+
|
13 |
+
if "optimizer" in parameters:
|
14 |
+
parameters = parameters["model"]
|
15 |
+
|
16 |
+
self.load_state_dict(parameters)
|
annotator/midas/midas/blocks.py
ADDED
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from .vit import (
|
5 |
+
_make_pretrained_vitb_rn50_384,
|
6 |
+
_make_pretrained_vitl16_384,
|
7 |
+
_make_pretrained_vitb16_384,
|
8 |
+
forward_vit,
|
9 |
+
)
|
10 |
+
|
11 |
+
def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
|
12 |
+
if backbone == "vitl16_384":
|
13 |
+
pretrained = _make_pretrained_vitl16_384(
|
14 |
+
use_pretrained, hooks=hooks, use_readout=use_readout
|
15 |
+
)
|
16 |
+
scratch = _make_scratch(
|
17 |
+
[256, 512, 1024, 1024], features, groups=groups, expand=expand
|
18 |
+
) # ViT-L/16 - 85.0% Top1 (backbone)
|
19 |
+
elif backbone == "vitb_rn50_384":
|
20 |
+
pretrained = _make_pretrained_vitb_rn50_384(
|
21 |
+
use_pretrained,
|
22 |
+
hooks=hooks,
|
23 |
+
use_vit_only=use_vit_only,
|
24 |
+
use_readout=use_readout,
|
25 |
+
)
|
26 |
+
scratch = _make_scratch(
|
27 |
+
[256, 512, 768, 768], features, groups=groups, expand=expand
|
28 |
+
) # ViT-H/16 - 85.0% Top1 (backbone)
|
29 |
+
elif backbone == "vitb16_384":
|
30 |
+
pretrained = _make_pretrained_vitb16_384(
|
31 |
+
use_pretrained, hooks=hooks, use_readout=use_readout
|
32 |
+
)
|
33 |
+
scratch = _make_scratch(
|
34 |
+
[96, 192, 384, 768], features, groups=groups, expand=expand
|
35 |
+
) # ViT-B/16 - 84.6% Top1 (backbone)
|
36 |
+
elif backbone == "resnext101_wsl":
|
37 |
+
pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
|
38 |
+
scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3
|
39 |
+
elif backbone == "efficientnet_lite3":
|
40 |
+
pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
|
41 |
+
scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
|
42 |
+
else:
|
43 |
+
print(f"Backbone '{backbone}' not implemented")
|
44 |
+
assert False
|
45 |
+
|
46 |
+
return pretrained, scratch
|
47 |
+
|
48 |
+
|
49 |
+
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
|
50 |
+
scratch = nn.Module()
|
51 |
+
|
52 |
+
out_shape1 = out_shape
|
53 |
+
out_shape2 = out_shape
|
54 |
+
out_shape3 = out_shape
|
55 |
+
out_shape4 = out_shape
|
56 |
+
if expand==True:
|
57 |
+
out_shape1 = out_shape
|
58 |
+
out_shape2 = out_shape*2
|
59 |
+
out_shape3 = out_shape*4
|
60 |
+
out_shape4 = out_shape*8
|
61 |
+
|
62 |
+
scratch.layer1_rn = nn.Conv2d(
|
63 |
+
in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
64 |
+
)
|
65 |
+
scratch.layer2_rn = nn.Conv2d(
|
66 |
+
in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
67 |
+
)
|
68 |
+
scratch.layer3_rn = nn.Conv2d(
|
69 |
+
in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
70 |
+
)
|
71 |
+
scratch.layer4_rn = nn.Conv2d(
|
72 |
+
in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
|
73 |
+
)
|
74 |
+
|
75 |
+
return scratch
|
76 |
+
|
77 |
+
|
78 |
+
def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
|
79 |
+
efficientnet = torch.hub.load(
|
80 |
+
"rwightman/gen-efficientnet-pytorch",
|
81 |
+
"tf_efficientnet_lite3",
|
82 |
+
pretrained=use_pretrained,
|
83 |
+
exportable=exportable
|
84 |
+
)
|
85 |
+
return _make_efficientnet_backbone(efficientnet)
|
86 |
+
|
87 |
+
|
88 |
+
def _make_efficientnet_backbone(effnet):
|
89 |
+
pretrained = nn.Module()
|
90 |
+
|
91 |
+
pretrained.layer1 = nn.Sequential(
|
92 |
+
effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
|
93 |
+
)
|
94 |
+
pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
|
95 |
+
pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
|
96 |
+
pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
|
97 |
+
|
98 |
+
return pretrained
|
99 |
+
|
100 |
+
|
101 |
+
def _make_resnet_backbone(resnet):
|
102 |
+
pretrained = nn.Module()
|
103 |
+
pretrained.layer1 = nn.Sequential(
|
104 |
+
resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
|
105 |
+
)
|
106 |
+
|
107 |
+
pretrained.layer2 = resnet.layer2
|
108 |
+
pretrained.layer3 = resnet.layer3
|
109 |
+
pretrained.layer4 = resnet.layer4
|
110 |
+
|
111 |
+
return pretrained
|
112 |
+
|
113 |
+
|
114 |
+
def _make_pretrained_resnext101_wsl(use_pretrained):
|
115 |
+
resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
|
116 |
+
return _make_resnet_backbone(resnet)
|
117 |
+
|
118 |
+
|
119 |
+
|
120 |
+
class Interpolate(nn.Module):
|
121 |
+
"""Interpolation module.
|
122 |
+
"""
|
123 |
+
|
124 |
+
def __init__(self, scale_factor, mode, align_corners=False):
|
125 |
+
"""Init.
|
126 |
+
|
127 |
+
Args:
|
128 |
+
scale_factor (float): scaling
|
129 |
+
mode (str): interpolation mode
|
130 |
+
"""
|
131 |
+
super(Interpolate, self).__init__()
|
132 |
+
|
133 |
+
self.interp = nn.functional.interpolate
|
134 |
+
self.scale_factor = scale_factor
|
135 |
+
self.mode = mode
|
136 |
+
self.align_corners = align_corners
|
137 |
+
|
138 |
+
def forward(self, x):
|
139 |
+
"""Forward pass.
|
140 |
+
|
141 |
+
Args:
|
142 |
+
x (tensor): input
|
143 |
+
|
144 |
+
Returns:
|
145 |
+
tensor: interpolated data
|
146 |
+
"""
|
147 |
+
|
148 |
+
x = self.interp(
|
149 |
+
x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
|
150 |
+
)
|
151 |
+
|
152 |
+
return x
|
153 |
+
|
154 |
+
|
155 |
+
class ResidualConvUnit(nn.Module):
|
156 |
+
"""Residual convolution module.
|
157 |
+
"""
|
158 |
+
|
159 |
+
def __init__(self, features):
|
160 |
+
"""Init.
|
161 |
+
|
162 |
+
Args:
|
163 |
+
features (int): number of features
|
164 |
+
"""
|
165 |
+
super().__init__()
|
166 |
+
|
167 |
+
self.conv1 = nn.Conv2d(
|
168 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True
|
169 |
+
)
|
170 |
+
|
171 |
+
self.conv2 = nn.Conv2d(
|
172 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True
|
173 |
+
)
|
174 |
+
|
175 |
+
self.relu = nn.ReLU(inplace=True)
|
176 |
+
|
177 |
+
def forward(self, x):
|
178 |
+
"""Forward pass.
|
179 |
+
|
180 |
+
Args:
|
181 |
+
x (tensor): input
|
182 |
+
|
183 |
+
Returns:
|
184 |
+
tensor: output
|
185 |
+
"""
|
186 |
+
out = self.relu(x)
|
187 |
+
out = self.conv1(out)
|
188 |
+
out = self.relu(out)
|
189 |
+
out = self.conv2(out)
|
190 |
+
|
191 |
+
return out + x
|
192 |
+
|
193 |
+
|
194 |
+
class FeatureFusionBlock(nn.Module):
|
195 |
+
"""Feature fusion block.
|
196 |
+
"""
|
197 |
+
|
198 |
+
def __init__(self, features):
|
199 |
+
"""Init.
|
200 |
+
|
201 |
+
Args:
|
202 |
+
features (int): number of features
|
203 |
+
"""
|
204 |
+
super(FeatureFusionBlock, self).__init__()
|
205 |
+
|
206 |
+
self.resConfUnit1 = ResidualConvUnit(features)
|
207 |
+
self.resConfUnit2 = ResidualConvUnit(features)
|
208 |
+
|
209 |
+
def forward(self, *xs):
|
210 |
+
"""Forward pass.
|
211 |
+
|
212 |
+
Returns:
|
213 |
+
tensor: output
|
214 |
+
"""
|
215 |
+
output = xs[0]
|
216 |
+
|
217 |
+
if len(xs) == 2:
|
218 |
+
output += self.resConfUnit1(xs[1])
|
219 |
+
|
220 |
+
output = self.resConfUnit2(output)
|
221 |
+
|
222 |
+
output = nn.functional.interpolate(
|
223 |
+
output, scale_factor=2, mode="bilinear", align_corners=True
|
224 |
+
)
|
225 |
+
|
226 |
+
return output
|
227 |
+
|
228 |
+
|
229 |
+
|
230 |
+
|
231 |
+
class ResidualConvUnit_custom(nn.Module):
|
232 |
+
"""Residual convolution module.
|
233 |
+
"""
|
234 |
+
|
235 |
+
def __init__(self, features, activation, bn):
|
236 |
+
"""Init.
|
237 |
+
|
238 |
+
Args:
|
239 |
+
features (int): number of features
|
240 |
+
"""
|
241 |
+
super().__init__()
|
242 |
+
|
243 |
+
self.bn = bn
|
244 |
+
|
245 |
+
self.groups=1
|
246 |
+
|
247 |
+
self.conv1 = nn.Conv2d(
|
248 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
|
249 |
+
)
|
250 |
+
|
251 |
+
self.conv2 = nn.Conv2d(
|
252 |
+
features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
|
253 |
+
)
|
254 |
+
|
255 |
+
if self.bn==True:
|
256 |
+
self.bn1 = nn.BatchNorm2d(features)
|
257 |
+
self.bn2 = nn.BatchNorm2d(features)
|
258 |
+
|
259 |
+
self.activation = activation
|
260 |
+
|
261 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
262 |
+
|
263 |
+
def forward(self, x):
|
264 |
+
"""Forward pass.
|
265 |
+
|
266 |
+
Args:
|
267 |
+
x (tensor): input
|
268 |
+
|
269 |
+
Returns:
|
270 |
+
tensor: output
|
271 |
+
"""
|
272 |
+
|
273 |
+
out = self.activation(x)
|
274 |
+
out = self.conv1(out)
|
275 |
+
if self.bn==True:
|
276 |
+
out = self.bn1(out)
|
277 |
+
|
278 |
+
out = self.activation(out)
|
279 |
+
out = self.conv2(out)
|
280 |
+
if self.bn==True:
|
281 |
+
out = self.bn2(out)
|
282 |
+
|
283 |
+
if self.groups > 1:
|
284 |
+
out = self.conv_merge(out)
|
285 |
+
|
286 |
+
return self.skip_add.add(out, x)
|
287 |
+
|
288 |
+
# return out + x
|
289 |
+
|
290 |
+
|
291 |
+
class FeatureFusionBlock_custom(nn.Module):
|
292 |
+
"""Feature fusion block.
|
293 |
+
"""
|
294 |
+
|
295 |
+
def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
|
296 |
+
"""Init.
|
297 |
+
|
298 |
+
Args:
|
299 |
+
features (int): number of features
|
300 |
+
"""
|
301 |
+
super(FeatureFusionBlock_custom, self).__init__()
|
302 |
+
|
303 |
+
self.deconv = deconv
|
304 |
+
self.align_corners = align_corners
|
305 |
+
|
306 |
+
self.groups=1
|
307 |
+
|
308 |
+
self.expand = expand
|
309 |
+
out_features = features
|
310 |
+
if self.expand==True:
|
311 |
+
out_features = features//2
|
312 |
+
|
313 |
+
self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
|
314 |
+
|
315 |
+
self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
|
316 |
+
self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
|
317 |
+
|
318 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
319 |
+
|
320 |
+
def forward(self, *xs):
|
321 |
+
"""Forward pass.
|
322 |
+
|
323 |
+
Returns:
|
324 |
+
tensor: output
|
325 |
+
"""
|
326 |
+
output = xs[0]
|
327 |
+
|
328 |
+
if len(xs) == 2:
|
329 |
+
res = self.resConfUnit1(xs[1])
|
330 |
+
output = self.skip_add.add(output, res)
|
331 |
+
# output += res
|
332 |
+
|
333 |
+
output = self.resConfUnit2(output)
|
334 |
+
|
335 |
+
output = nn.functional.interpolate(
|
336 |
+
output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
|
337 |
+
)
|
338 |
+
|
339 |
+
output = self.out_conv(output)
|
340 |
+
|
341 |
+
return output
|
342 |
+
|
annotator/midas/midas/dpt_depth.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from .base_model import BaseModel
|
6 |
+
from .blocks import (
|
7 |
+
FeatureFusionBlock,
|
8 |
+
FeatureFusionBlock_custom,
|
9 |
+
Interpolate,
|
10 |
+
_make_encoder,
|
11 |
+
forward_vit,
|
12 |
+
)
|
13 |
+
|
14 |
+
|
15 |
+
def _make_fusion_block(features, use_bn):
|
16 |
+
return FeatureFusionBlock_custom(
|
17 |
+
features,
|
18 |
+
nn.ReLU(False),
|
19 |
+
deconv=False,
|
20 |
+
bn=use_bn,
|
21 |
+
expand=False,
|
22 |
+
align_corners=True,
|
23 |
+
)
|
24 |
+
|
25 |
+
|
26 |
+
class DPT(BaseModel):
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
head,
|
30 |
+
features=256,
|
31 |
+
backbone="vitb_rn50_384",
|
32 |
+
readout="project",
|
33 |
+
channels_last=False,
|
34 |
+
use_bn=False,
|
35 |
+
):
|
36 |
+
|
37 |
+
super(DPT, self).__init__()
|
38 |
+
|
39 |
+
self.channels_last = channels_last
|
40 |
+
|
41 |
+
hooks = {
|
42 |
+
"vitb_rn50_384": [0, 1, 8, 11],
|
43 |
+
"vitb16_384": [2, 5, 8, 11],
|
44 |
+
"vitl16_384": [5, 11, 17, 23],
|
45 |
+
}
|
46 |
+
|
47 |
+
# Instantiate backbone and reassemble blocks
|
48 |
+
self.pretrained, self.scratch = _make_encoder(
|
49 |
+
backbone,
|
50 |
+
features,
|
51 |
+
False, # Set to true of you want to train from scratch, uses ImageNet weights
|
52 |
+
groups=1,
|
53 |
+
expand=False,
|
54 |
+
exportable=False,
|
55 |
+
hooks=hooks[backbone],
|
56 |
+
use_readout=readout,
|
57 |
+
)
|
58 |
+
|
59 |
+
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
|
60 |
+
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
|
61 |
+
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
|
62 |
+
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
|
63 |
+
|
64 |
+
self.scratch.output_conv = head
|
65 |
+
|
66 |
+
|
67 |
+
def forward(self, x):
|
68 |
+
if self.channels_last == True:
|
69 |
+
x.contiguous(memory_format=torch.channels_last)
|
70 |
+
|
71 |
+
layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
|
72 |
+
|
73 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
74 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
75 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
76 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
77 |
+
|
78 |
+
path_4 = self.scratch.refinenet4(layer_4_rn)
|
79 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
80 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
81 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
82 |
+
|
83 |
+
out = self.scratch.output_conv(path_1)
|
84 |
+
|
85 |
+
return out
|
86 |
+
|
87 |
+
|
88 |
+
class DPTDepthModel(DPT):
|
89 |
+
def __init__(self, path=None, non_negative=True, **kwargs):
|
90 |
+
features = kwargs["features"] if "features" in kwargs else 256
|
91 |
+
|
92 |
+
head = nn.Sequential(
|
93 |
+
nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
|
94 |
+
Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
|
95 |
+
nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
|
96 |
+
nn.ReLU(True),
|
97 |
+
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
98 |
+
nn.ReLU(True) if non_negative else nn.Identity(),
|
99 |
+
nn.Identity(),
|
100 |
+
)
|
101 |
+
|
102 |
+
super().__init__(head, **kwargs)
|
103 |
+
|
104 |
+
if path is not None:
|
105 |
+
self.load(path)
|
106 |
+
|
107 |
+
def forward(self, x):
|
108 |
+
return super().forward(x).squeeze(dim=1)
|
109 |
+
|
annotator/midas/midas/midas_net.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
|
2 |
+
This file contains code that is adapted from
|
3 |
+
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
|
4 |
+
"""
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
from .base_model import BaseModel
|
9 |
+
from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
|
10 |
+
|
11 |
+
|
12 |
+
class MidasNet(BaseModel):
|
13 |
+
"""Network for monocular depth estimation.
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self, path=None, features=256, non_negative=True):
|
17 |
+
"""Init.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
path (str, optional): Path to saved model. Defaults to None.
|
21 |
+
features (int, optional): Number of features. Defaults to 256.
|
22 |
+
backbone (str, optional): Backbone network for encoder. Defaults to resnet50
|
23 |
+
"""
|
24 |
+
print("Loading weights: ", path)
|
25 |
+
|
26 |
+
super(MidasNet, self).__init__()
|
27 |
+
|
28 |
+
use_pretrained = False if path is None else True
|
29 |
+
|
30 |
+
self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
|
31 |
+
|
32 |
+
self.scratch.refinenet4 = FeatureFusionBlock(features)
|
33 |
+
self.scratch.refinenet3 = FeatureFusionBlock(features)
|
34 |
+
self.scratch.refinenet2 = FeatureFusionBlock(features)
|
35 |
+
self.scratch.refinenet1 = FeatureFusionBlock(features)
|
36 |
+
|
37 |
+
self.scratch.output_conv = nn.Sequential(
|
38 |
+
nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
|
39 |
+
Interpolate(scale_factor=2, mode="bilinear"),
|
40 |
+
nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
|
41 |
+
nn.ReLU(True),
|
42 |
+
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
43 |
+
nn.ReLU(True) if non_negative else nn.Identity(),
|
44 |
+
)
|
45 |
+
|
46 |
+
if path:
|
47 |
+
self.load(path)
|
48 |
+
|
49 |
+
def forward(self, x):
|
50 |
+
"""Forward pass.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
x (tensor): input data (image)
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
tensor: depth
|
57 |
+
"""
|
58 |
+
|
59 |
+
layer_1 = self.pretrained.layer1(x)
|
60 |
+
layer_2 = self.pretrained.layer2(layer_1)
|
61 |
+
layer_3 = self.pretrained.layer3(layer_2)
|
62 |
+
layer_4 = self.pretrained.layer4(layer_3)
|
63 |
+
|
64 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
65 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
66 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
67 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
68 |
+
|
69 |
+
path_4 = self.scratch.refinenet4(layer_4_rn)
|
70 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
71 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
72 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
73 |
+
|
74 |
+
out = self.scratch.output_conv(path_1)
|
75 |
+
|
76 |
+
return torch.squeeze(out, dim=1)
|
annotator/midas/midas/midas_net_custom.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
|
2 |
+
This file contains code that is adapted from
|
3 |
+
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
|
4 |
+
"""
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
from .base_model import BaseModel
|
9 |
+
from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
|
10 |
+
|
11 |
+
|
12 |
+
class MidasNet_small(BaseModel):
|
13 |
+
"""Network for monocular depth estimation.
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
|
17 |
+
blocks={'expand': True}):
|
18 |
+
"""Init.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
path (str, optional): Path to saved model. Defaults to None.
|
22 |
+
features (int, optional): Number of features. Defaults to 256.
|
23 |
+
backbone (str, optional): Backbone network for encoder. Defaults to resnet50
|
24 |
+
"""
|
25 |
+
print("Loading weights: ", path)
|
26 |
+
|
27 |
+
super(MidasNet_small, self).__init__()
|
28 |
+
|
29 |
+
use_pretrained = False if path else True
|
30 |
+
|
31 |
+
self.channels_last = channels_last
|
32 |
+
self.blocks = blocks
|
33 |
+
self.backbone = backbone
|
34 |
+
|
35 |
+
self.groups = 1
|
36 |
+
|
37 |
+
features1=features
|
38 |
+
features2=features
|
39 |
+
features3=features
|
40 |
+
features4=features
|
41 |
+
self.expand = False
|
42 |
+
if "expand" in self.blocks and self.blocks['expand'] == True:
|
43 |
+
self.expand = True
|
44 |
+
features1=features
|
45 |
+
features2=features*2
|
46 |
+
features3=features*4
|
47 |
+
features4=features*8
|
48 |
+
|
49 |
+
self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
|
50 |
+
|
51 |
+
self.scratch.activation = nn.ReLU(False)
|
52 |
+
|
53 |
+
self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
|
54 |
+
self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
|
55 |
+
self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
|
56 |
+
self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
|
57 |
+
|
58 |
+
|
59 |
+
self.scratch.output_conv = nn.Sequential(
|
60 |
+
nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
|
61 |
+
Interpolate(scale_factor=2, mode="bilinear"),
|
62 |
+
nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
|
63 |
+
self.scratch.activation,
|
64 |
+
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
65 |
+
nn.ReLU(True) if non_negative else nn.Identity(),
|
66 |
+
nn.Identity(),
|
67 |
+
)
|
68 |
+
|
69 |
+
if path:
|
70 |
+
self.load(path)
|
71 |
+
|
72 |
+
|
73 |
+
def forward(self, x):
|
74 |
+
"""Forward pass.
|
75 |
+
|
76 |
+
Args:
|
77 |
+
x (tensor): input data (image)
|
78 |
+
|
79 |
+
Returns:
|
80 |
+
tensor: depth
|
81 |
+
"""
|
82 |
+
if self.channels_last==True:
|
83 |
+
print("self.channels_last = ", self.channels_last)
|
84 |
+
x.contiguous(memory_format=torch.channels_last)
|
85 |
+
|
86 |
+
|
87 |
+
layer_1 = self.pretrained.layer1(x)
|
88 |
+
layer_2 = self.pretrained.layer2(layer_1)
|
89 |
+
layer_3 = self.pretrained.layer3(layer_2)
|
90 |
+
layer_4 = self.pretrained.layer4(layer_3)
|
91 |
+
|
92 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
93 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
94 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
95 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
96 |
+
|
97 |
+
|
98 |
+
path_4 = self.scratch.refinenet4(layer_4_rn)
|
99 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
100 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
101 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
102 |
+
|
103 |
+
out = self.scratch.output_conv(path_1)
|
104 |
+
|
105 |
+
return torch.squeeze(out, dim=1)
|
106 |
+
|
107 |
+
|
108 |
+
|
109 |
+
def fuse_model(m):
|
110 |
+
prev_previous_type = nn.Identity()
|
111 |
+
prev_previous_name = ''
|
112 |
+
previous_type = nn.Identity()
|
113 |
+
previous_name = ''
|
114 |
+
for name, module in m.named_modules():
|
115 |
+
if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
|
116 |
+
# print("FUSED ", prev_previous_name, previous_name, name)
|
117 |
+
torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
|
118 |
+
elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
|
119 |
+
# print("FUSED ", prev_previous_name, previous_name)
|
120 |
+
torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
|
121 |
+
# elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
|
122 |
+
# print("FUSED ", previous_name, name)
|
123 |
+
# torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
|
124 |
+
|
125 |
+
prev_previous_type = previous_type
|
126 |
+
prev_previous_name = previous_name
|
127 |
+
previous_type = type(module)
|
128 |
+
previous_name = name
|
annotator/midas/midas/transforms.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import cv2
|
3 |
+
import math
|
4 |
+
|
5 |
+
|
6 |
+
def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
|
7 |
+
"""Rezise the sample to ensure the given size. Keeps aspect ratio.
|
8 |
+
|
9 |
+
Args:
|
10 |
+
sample (dict): sample
|
11 |
+
size (tuple): image size
|
12 |
+
|
13 |
+
Returns:
|
14 |
+
tuple: new size
|
15 |
+
"""
|
16 |
+
shape = list(sample["disparity"].shape)
|
17 |
+
|
18 |
+
if shape[0] >= size[0] and shape[1] >= size[1]:
|
19 |
+
return sample
|
20 |
+
|
21 |
+
scale = [0, 0]
|
22 |
+
scale[0] = size[0] / shape[0]
|
23 |
+
scale[1] = size[1] / shape[1]
|
24 |
+
|
25 |
+
scale = max(scale)
|
26 |
+
|
27 |
+
shape[0] = math.ceil(scale * shape[0])
|
28 |
+
shape[1] = math.ceil(scale * shape[1])
|
29 |
+
|
30 |
+
# resize
|
31 |
+
sample["image"] = cv2.resize(
|
32 |
+
sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
|
33 |
+
)
|
34 |
+
|
35 |
+
sample["disparity"] = cv2.resize(
|
36 |
+
sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
|
37 |
+
)
|
38 |
+
sample["mask"] = cv2.resize(
|
39 |
+
sample["mask"].astype(np.float32),
|
40 |
+
tuple(shape[::-1]),
|
41 |
+
interpolation=cv2.INTER_NEAREST,
|
42 |
+
)
|
43 |
+
sample["mask"] = sample["mask"].astype(bool)
|
44 |
+
|
45 |
+
return tuple(shape)
|
46 |
+
|
47 |
+
|
48 |
+
class Resize(object):
|
49 |
+
"""Resize sample to given size (width, height).
|
50 |
+
"""
|
51 |
+
|
52 |
+
def __init__(
|
53 |
+
self,
|
54 |
+
width,
|
55 |
+
height,
|
56 |
+
resize_target=True,
|
57 |
+
keep_aspect_ratio=False,
|
58 |
+
ensure_multiple_of=1,
|
59 |
+
resize_method="lower_bound",
|
60 |
+
image_interpolation_method=cv2.INTER_AREA,
|
61 |
+
):
|
62 |
+
"""Init.
|
63 |
+
|
64 |
+
Args:
|
65 |
+
width (int): desired output width
|
66 |
+
height (int): desired output height
|
67 |
+
resize_target (bool, optional):
|
68 |
+
True: Resize the full sample (image, mask, target).
|
69 |
+
False: Resize image only.
|
70 |
+
Defaults to True.
|
71 |
+
keep_aspect_ratio (bool, optional):
|
72 |
+
True: Keep the aspect ratio of the input sample.
|
73 |
+
Output sample might not have the given width and height, and
|
74 |
+
resize behaviour depends on the parameter 'resize_method'.
|
75 |
+
Defaults to False.
|
76 |
+
ensure_multiple_of (int, optional):
|
77 |
+
Output width and height is constrained to be multiple of this parameter.
|
78 |
+
Defaults to 1.
|
79 |
+
resize_method (str, optional):
|
80 |
+
"lower_bound": Output will be at least as large as the given size.
|
81 |
+
"upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
|
82 |
+
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
|
83 |
+
Defaults to "lower_bound".
|
84 |
+
"""
|
85 |
+
self.__width = width
|
86 |
+
self.__height = height
|
87 |
+
|
88 |
+
self.__resize_target = resize_target
|
89 |
+
self.__keep_aspect_ratio = keep_aspect_ratio
|
90 |
+
self.__multiple_of = ensure_multiple_of
|
91 |
+
self.__resize_method = resize_method
|
92 |
+
self.__image_interpolation_method = image_interpolation_method
|
93 |
+
|
94 |
+
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
|
95 |
+
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
96 |
+
|
97 |
+
if max_val is not None and y > max_val:
|
98 |
+
y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
99 |
+
|
100 |
+
if y < min_val:
|
101 |
+
y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
102 |
+
|
103 |
+
return y
|
104 |
+
|
105 |
+
def get_size(self, width, height):
|
106 |
+
# determine new height and width
|
107 |
+
scale_height = self.__height / height
|
108 |
+
scale_width = self.__width / width
|
109 |
+
|
110 |
+
if self.__keep_aspect_ratio:
|
111 |
+
if self.__resize_method == "lower_bound":
|
112 |
+
# scale such that output size is lower bound
|
113 |
+
if scale_width > scale_height:
|
114 |
+
# fit width
|
115 |
+
scale_height = scale_width
|
116 |
+
else:
|
117 |
+
# fit height
|
118 |
+
scale_width = scale_height
|
119 |
+
elif self.__resize_method == "upper_bound":
|
120 |
+
# scale such that output size is upper bound
|
121 |
+
if scale_width < scale_height:
|
122 |
+
# fit width
|
123 |
+
scale_height = scale_width
|
124 |
+
else:
|
125 |
+
# fit height
|
126 |
+
scale_width = scale_height
|
127 |
+
elif self.__resize_method == "minimal":
|
128 |
+
# scale as least as possbile
|
129 |
+
if abs(1 - scale_width) < abs(1 - scale_height):
|
130 |
+
# fit width
|
131 |
+
scale_height = scale_width
|
132 |
+
else:
|
133 |
+
# fit height
|
134 |
+
scale_width = scale_height
|
135 |
+
else:
|
136 |
+
raise ValueError(
|
137 |
+
f"resize_method {self.__resize_method} not implemented"
|
138 |
+
)
|
139 |
+
|
140 |
+
if self.__resize_method == "lower_bound":
|
141 |
+
new_height = self.constrain_to_multiple_of(
|
142 |
+
scale_height * height, min_val=self.__height
|
143 |
+
)
|
144 |
+
new_width = self.constrain_to_multiple_of(
|
145 |
+
scale_width * width, min_val=self.__width
|
146 |
+
)
|
147 |
+
elif self.__resize_method == "upper_bound":
|
148 |
+
new_height = self.constrain_to_multiple_of(
|
149 |
+
scale_height * height, max_val=self.__height
|
150 |
+
)
|
151 |
+
new_width = self.constrain_to_multiple_of(
|
152 |
+
scale_width * width, max_val=self.__width
|
153 |
+
)
|
154 |
+
elif self.__resize_method == "minimal":
|
155 |
+
new_height = self.constrain_to_multiple_of(scale_height * height)
|
156 |
+
new_width = self.constrain_to_multiple_of(scale_width * width)
|
157 |
+
else:
|
158 |
+
raise ValueError(f"resize_method {self.__resize_method} not implemented")
|
159 |
+
|
160 |
+
return (new_width, new_height)
|
161 |
+
|
162 |
+
def __call__(self, sample):
|
163 |
+
width, height = self.get_size(
|
164 |
+
sample["image"].shape[1], sample["image"].shape[0]
|
165 |
+
)
|
166 |
+
|
167 |
+
# resize sample
|
168 |
+
sample["image"] = cv2.resize(
|
169 |
+
sample["image"],
|
170 |
+
(width, height),
|
171 |
+
interpolation=self.__image_interpolation_method,
|
172 |
+
)
|
173 |
+
|
174 |
+
if self.__resize_target:
|
175 |
+
if "disparity" in sample:
|
176 |
+
sample["disparity"] = cv2.resize(
|
177 |
+
sample["disparity"],
|
178 |
+
(width, height),
|
179 |
+
interpolation=cv2.INTER_NEAREST,
|
180 |
+
)
|
181 |
+
|
182 |
+
if "depth" in sample:
|
183 |
+
sample["depth"] = cv2.resize(
|
184 |
+
sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
|
185 |
+
)
|
186 |
+
|
187 |
+
sample["mask"] = cv2.resize(
|
188 |
+
sample["mask"].astype(np.float32),
|
189 |
+
(width, height),
|
190 |
+
interpolation=cv2.INTER_NEAREST,
|
191 |
+
)
|
192 |
+
sample["mask"] = sample["mask"].astype(bool)
|
193 |
+
|
194 |
+
return sample
|
195 |
+
|
196 |
+
|
197 |
+
class NormalizeImage(object):
|
198 |
+
"""Normlize image by given mean and std.
|
199 |
+
"""
|
200 |
+
|
201 |
+
def __init__(self, mean, std):
|
202 |
+
self.__mean = mean
|
203 |
+
self.__std = std
|
204 |
+
|
205 |
+
def __call__(self, sample):
|
206 |
+
sample["image"] = (sample["image"] - self.__mean) / self.__std
|
207 |
+
|
208 |
+
return sample
|
209 |
+
|
210 |
+
|
211 |
+
class PrepareForNet(object):
|
212 |
+
"""Prepare sample for usage as network input.
|
213 |
+
"""
|
214 |
+
|
215 |
+
def __init__(self):
|
216 |
+
pass
|
217 |
+
|
218 |
+
def __call__(self, sample):
|
219 |
+
image = np.transpose(sample["image"], (2, 0, 1))
|
220 |
+
sample["image"] = np.ascontiguousarray(image).astype(np.float32)
|
221 |
+
|
222 |
+
if "mask" in sample:
|
223 |
+
sample["mask"] = sample["mask"].astype(np.float32)
|
224 |
+
sample["mask"] = np.ascontiguousarray(sample["mask"])
|
225 |
+
|
226 |
+
if "disparity" in sample:
|
227 |
+
disparity = sample["disparity"].astype(np.float32)
|
228 |
+
sample["disparity"] = np.ascontiguousarray(disparity)
|
229 |
+
|
230 |
+
if "depth" in sample:
|
231 |
+
depth = sample["depth"].astype(np.float32)
|
232 |
+
sample["depth"] = np.ascontiguousarray(depth)
|
233 |
+
|
234 |
+
return sample
|
annotator/midas/midas/vit.py
ADDED
@@ -0,0 +1,491 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import timm
|
4 |
+
import types
|
5 |
+
import math
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
|
9 |
+
class Slice(nn.Module):
|
10 |
+
def __init__(self, start_index=1):
|
11 |
+
super(Slice, self).__init__()
|
12 |
+
self.start_index = start_index
|
13 |
+
|
14 |
+
def forward(self, x):
|
15 |
+
return x[:, self.start_index :]
|
16 |
+
|
17 |
+
|
18 |
+
class AddReadout(nn.Module):
|
19 |
+
def __init__(self, start_index=1):
|
20 |
+
super(AddReadout, self).__init__()
|
21 |
+
self.start_index = start_index
|
22 |
+
|
23 |
+
def forward(self, x):
|
24 |
+
if self.start_index == 2:
|
25 |
+
readout = (x[:, 0] + x[:, 1]) / 2
|
26 |
+
else:
|
27 |
+
readout = x[:, 0]
|
28 |
+
return x[:, self.start_index :] + readout.unsqueeze(1)
|
29 |
+
|
30 |
+
|
31 |
+
class ProjectReadout(nn.Module):
|
32 |
+
def __init__(self, in_features, start_index=1):
|
33 |
+
super(ProjectReadout, self).__init__()
|
34 |
+
self.start_index = start_index
|
35 |
+
|
36 |
+
self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
|
37 |
+
|
38 |
+
def forward(self, x):
|
39 |
+
readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
|
40 |
+
features = torch.cat((x[:, self.start_index :], readout), -1)
|
41 |
+
|
42 |
+
return self.project(features)
|
43 |
+
|
44 |
+
|
45 |
+
class Transpose(nn.Module):
|
46 |
+
def __init__(self, dim0, dim1):
|
47 |
+
super(Transpose, self).__init__()
|
48 |
+
self.dim0 = dim0
|
49 |
+
self.dim1 = dim1
|
50 |
+
|
51 |
+
def forward(self, x):
|
52 |
+
x = x.transpose(self.dim0, self.dim1)
|
53 |
+
return x
|
54 |
+
|
55 |
+
|
56 |
+
def forward_vit(pretrained, x):
|
57 |
+
b, c, h, w = x.shape
|
58 |
+
|
59 |
+
glob = pretrained.model.forward_flex(x)
|
60 |
+
|
61 |
+
layer_1 = pretrained.activations["1"]
|
62 |
+
layer_2 = pretrained.activations["2"]
|
63 |
+
layer_3 = pretrained.activations["3"]
|
64 |
+
layer_4 = pretrained.activations["4"]
|
65 |
+
|
66 |
+
layer_1 = pretrained.act_postprocess1[0:2](layer_1)
|
67 |
+
layer_2 = pretrained.act_postprocess2[0:2](layer_2)
|
68 |
+
layer_3 = pretrained.act_postprocess3[0:2](layer_3)
|
69 |
+
layer_4 = pretrained.act_postprocess4[0:2](layer_4)
|
70 |
+
|
71 |
+
unflatten = nn.Sequential(
|
72 |
+
nn.Unflatten(
|
73 |
+
2,
|
74 |
+
torch.Size(
|
75 |
+
[
|
76 |
+
h // pretrained.model.patch_size[1],
|
77 |
+
w // pretrained.model.patch_size[0],
|
78 |
+
]
|
79 |
+
),
|
80 |
+
)
|
81 |
+
)
|
82 |
+
|
83 |
+
if layer_1.ndim == 3:
|
84 |
+
layer_1 = unflatten(layer_1)
|
85 |
+
if layer_2.ndim == 3:
|
86 |
+
layer_2 = unflatten(layer_2)
|
87 |
+
if layer_3.ndim == 3:
|
88 |
+
layer_3 = unflatten(layer_3)
|
89 |
+
if layer_4.ndim == 3:
|
90 |
+
layer_4 = unflatten(layer_4)
|
91 |
+
|
92 |
+
layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
|
93 |
+
layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
|
94 |
+
layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
|
95 |
+
layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
|
96 |
+
|
97 |
+
return layer_1, layer_2, layer_3, layer_4
|
98 |
+
|
99 |
+
|
100 |
+
def _resize_pos_embed(self, posemb, gs_h, gs_w):
|
101 |
+
posemb_tok, posemb_grid = (
|
102 |
+
posemb[:, : self.start_index],
|
103 |
+
posemb[0, self.start_index :],
|
104 |
+
)
|
105 |
+
|
106 |
+
gs_old = int(math.sqrt(len(posemb_grid)))
|
107 |
+
|
108 |
+
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
|
109 |
+
posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
|
110 |
+
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
|
111 |
+
|
112 |
+
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
|
113 |
+
|
114 |
+
return posemb
|
115 |
+
|
116 |
+
|
117 |
+
def forward_flex(self, x):
|
118 |
+
b, c, h, w = x.shape
|
119 |
+
|
120 |
+
pos_embed = self._resize_pos_embed(
|
121 |
+
self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
|
122 |
+
)
|
123 |
+
|
124 |
+
B = x.shape[0]
|
125 |
+
|
126 |
+
if hasattr(self.patch_embed, "backbone"):
|
127 |
+
x = self.patch_embed.backbone(x)
|
128 |
+
if isinstance(x, (list, tuple)):
|
129 |
+
x = x[-1] # last feature if backbone outputs list/tuple of features
|
130 |
+
|
131 |
+
x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
|
132 |
+
|
133 |
+
if getattr(self, "dist_token", None) is not None:
|
134 |
+
cls_tokens = self.cls_token.expand(
|
135 |
+
B, -1, -1
|
136 |
+
) # stole cls_tokens impl from Phil Wang, thanks
|
137 |
+
dist_token = self.dist_token.expand(B, -1, -1)
|
138 |
+
x = torch.cat((cls_tokens, dist_token, x), dim=1)
|
139 |
+
else:
|
140 |
+
cls_tokens = self.cls_token.expand(
|
141 |
+
B, -1, -1
|
142 |
+
) # stole cls_tokens impl from Phil Wang, thanks
|
143 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
144 |
+
|
145 |
+
x = x + pos_embed
|
146 |
+
x = self.pos_drop(x)
|
147 |
+
|
148 |
+
for blk in self.blocks:
|
149 |
+
x = blk(x)
|
150 |
+
|
151 |
+
x = self.norm(x)
|
152 |
+
|
153 |
+
return x
|
154 |
+
|
155 |
+
|
156 |
+
activations = {}
|
157 |
+
|
158 |
+
|
159 |
+
def get_activation(name):
|
160 |
+
def hook(model, input, output):
|
161 |
+
activations[name] = output
|
162 |
+
|
163 |
+
return hook
|
164 |
+
|
165 |
+
|
166 |
+
def get_readout_oper(vit_features, features, use_readout, start_index=1):
|
167 |
+
if use_readout == "ignore":
|
168 |
+
readout_oper = [Slice(start_index)] * len(features)
|
169 |
+
elif use_readout == "add":
|
170 |
+
readout_oper = [AddReadout(start_index)] * len(features)
|
171 |
+
elif use_readout == "project":
|
172 |
+
readout_oper = [
|
173 |
+
ProjectReadout(vit_features, start_index) for out_feat in features
|
174 |
+
]
|
175 |
+
else:
|
176 |
+
assert (
|
177 |
+
False
|
178 |
+
), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
|
179 |
+
|
180 |
+
return readout_oper
|
181 |
+
|
182 |
+
|
183 |
+
def _make_vit_b16_backbone(
|
184 |
+
model,
|
185 |
+
features=[96, 192, 384, 768],
|
186 |
+
size=[384, 384],
|
187 |
+
hooks=[2, 5, 8, 11],
|
188 |
+
vit_features=768,
|
189 |
+
use_readout="ignore",
|
190 |
+
start_index=1,
|
191 |
+
):
|
192 |
+
pretrained = nn.Module()
|
193 |
+
|
194 |
+
pretrained.model = model
|
195 |
+
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
|
196 |
+
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
|
197 |
+
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
|
198 |
+
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
|
199 |
+
|
200 |
+
pretrained.activations = activations
|
201 |
+
|
202 |
+
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
|
203 |
+
|
204 |
+
# 32, 48, 136, 384
|
205 |
+
pretrained.act_postprocess1 = nn.Sequential(
|
206 |
+
readout_oper[0],
|
207 |
+
Transpose(1, 2),
|
208 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
209 |
+
nn.Conv2d(
|
210 |
+
in_channels=vit_features,
|
211 |
+
out_channels=features[0],
|
212 |
+
kernel_size=1,
|
213 |
+
stride=1,
|
214 |
+
padding=0,
|
215 |
+
),
|
216 |
+
nn.ConvTranspose2d(
|
217 |
+
in_channels=features[0],
|
218 |
+
out_channels=features[0],
|
219 |
+
kernel_size=4,
|
220 |
+
stride=4,
|
221 |
+
padding=0,
|
222 |
+
bias=True,
|
223 |
+
dilation=1,
|
224 |
+
groups=1,
|
225 |
+
),
|
226 |
+
)
|
227 |
+
|
228 |
+
pretrained.act_postprocess2 = nn.Sequential(
|
229 |
+
readout_oper[1],
|
230 |
+
Transpose(1, 2),
|
231 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
232 |
+
nn.Conv2d(
|
233 |
+
in_channels=vit_features,
|
234 |
+
out_channels=features[1],
|
235 |
+
kernel_size=1,
|
236 |
+
stride=1,
|
237 |
+
padding=0,
|
238 |
+
),
|
239 |
+
nn.ConvTranspose2d(
|
240 |
+
in_channels=features[1],
|
241 |
+
out_channels=features[1],
|
242 |
+
kernel_size=2,
|
243 |
+
stride=2,
|
244 |
+
padding=0,
|
245 |
+
bias=True,
|
246 |
+
dilation=1,
|
247 |
+
groups=1,
|
248 |
+
),
|
249 |
+
)
|
250 |
+
|
251 |
+
pretrained.act_postprocess3 = nn.Sequential(
|
252 |
+
readout_oper[2],
|
253 |
+
Transpose(1, 2),
|
254 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
255 |
+
nn.Conv2d(
|
256 |
+
in_channels=vit_features,
|
257 |
+
out_channels=features[2],
|
258 |
+
kernel_size=1,
|
259 |
+
stride=1,
|
260 |
+
padding=0,
|
261 |
+
),
|
262 |
+
)
|
263 |
+
|
264 |
+
pretrained.act_postprocess4 = nn.Sequential(
|
265 |
+
readout_oper[3],
|
266 |
+
Transpose(1, 2),
|
267 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
268 |
+
nn.Conv2d(
|
269 |
+
in_channels=vit_features,
|
270 |
+
out_channels=features[3],
|
271 |
+
kernel_size=1,
|
272 |
+
stride=1,
|
273 |
+
padding=0,
|
274 |
+
),
|
275 |
+
nn.Conv2d(
|
276 |
+
in_channels=features[3],
|
277 |
+
out_channels=features[3],
|
278 |
+
kernel_size=3,
|
279 |
+
stride=2,
|
280 |
+
padding=1,
|
281 |
+
),
|
282 |
+
)
|
283 |
+
|
284 |
+
pretrained.model.start_index = start_index
|
285 |
+
pretrained.model.patch_size = [16, 16]
|
286 |
+
|
287 |
+
# We inject this function into the VisionTransformer instances so that
|
288 |
+
# we can use it with interpolated position embeddings without modifying the library source.
|
289 |
+
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
|
290 |
+
pretrained.model._resize_pos_embed = types.MethodType(
|
291 |
+
_resize_pos_embed, pretrained.model
|
292 |
+
)
|
293 |
+
|
294 |
+
return pretrained
|
295 |
+
|
296 |
+
|
297 |
+
def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
|
298 |
+
model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
|
299 |
+
|
300 |
+
hooks = [5, 11, 17, 23] if hooks == None else hooks
|
301 |
+
return _make_vit_b16_backbone(
|
302 |
+
model,
|
303 |
+
features=[256, 512, 1024, 1024],
|
304 |
+
hooks=hooks,
|
305 |
+
vit_features=1024,
|
306 |
+
use_readout=use_readout,
|
307 |
+
)
|
308 |
+
|
309 |
+
|
310 |
+
def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
|
311 |
+
model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
|
312 |
+
|
313 |
+
hooks = [2, 5, 8, 11] if hooks == None else hooks
|
314 |
+
return _make_vit_b16_backbone(
|
315 |
+
model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
|
316 |
+
)
|
317 |
+
|
318 |
+
|
319 |
+
def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
|
320 |
+
model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
|
321 |
+
|
322 |
+
hooks = [2, 5, 8, 11] if hooks == None else hooks
|
323 |
+
return _make_vit_b16_backbone(
|
324 |
+
model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
|
325 |
+
)
|
326 |
+
|
327 |
+
|
328 |
+
def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
|
329 |
+
model = timm.create_model(
|
330 |
+
"vit_deit_base_distilled_patch16_384", pretrained=pretrained
|
331 |
+
)
|
332 |
+
|
333 |
+
hooks = [2, 5, 8, 11] if hooks == None else hooks
|
334 |
+
return _make_vit_b16_backbone(
|
335 |
+
model,
|
336 |
+
features=[96, 192, 384, 768],
|
337 |
+
hooks=hooks,
|
338 |
+
use_readout=use_readout,
|
339 |
+
start_index=2,
|
340 |
+
)
|
341 |
+
|
342 |
+
|
343 |
+
def _make_vit_b_rn50_backbone(
|
344 |
+
model,
|
345 |
+
features=[256, 512, 768, 768],
|
346 |
+
size=[384, 384],
|
347 |
+
hooks=[0, 1, 8, 11],
|
348 |
+
vit_features=768,
|
349 |
+
use_vit_only=False,
|
350 |
+
use_readout="ignore",
|
351 |
+
start_index=1,
|
352 |
+
):
|
353 |
+
pretrained = nn.Module()
|
354 |
+
|
355 |
+
pretrained.model = model
|
356 |
+
|
357 |
+
if use_vit_only == True:
|
358 |
+
pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
|
359 |
+
pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
|
360 |
+
else:
|
361 |
+
pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
|
362 |
+
get_activation("1")
|
363 |
+
)
|
364 |
+
pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
|
365 |
+
get_activation("2")
|
366 |
+
)
|
367 |
+
|
368 |
+
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
|
369 |
+
pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
|
370 |
+
|
371 |
+
pretrained.activations = activations
|
372 |
+
|
373 |
+
readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
|
374 |
+
|
375 |
+
if use_vit_only == True:
|
376 |
+
pretrained.act_postprocess1 = nn.Sequential(
|
377 |
+
readout_oper[0],
|
378 |
+
Transpose(1, 2),
|
379 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
380 |
+
nn.Conv2d(
|
381 |
+
in_channels=vit_features,
|
382 |
+
out_channels=features[0],
|
383 |
+
kernel_size=1,
|
384 |
+
stride=1,
|
385 |
+
padding=0,
|
386 |
+
),
|
387 |
+
nn.ConvTranspose2d(
|
388 |
+
in_channels=features[0],
|
389 |
+
out_channels=features[0],
|
390 |
+
kernel_size=4,
|
391 |
+
stride=4,
|
392 |
+
padding=0,
|
393 |
+
bias=True,
|
394 |
+
dilation=1,
|
395 |
+
groups=1,
|
396 |
+
),
|
397 |
+
)
|
398 |
+
|
399 |
+
pretrained.act_postprocess2 = nn.Sequential(
|
400 |
+
readout_oper[1],
|
401 |
+
Transpose(1, 2),
|
402 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
403 |
+
nn.Conv2d(
|
404 |
+
in_channels=vit_features,
|
405 |
+
out_channels=features[1],
|
406 |
+
kernel_size=1,
|
407 |
+
stride=1,
|
408 |
+
padding=0,
|
409 |
+
),
|
410 |
+
nn.ConvTranspose2d(
|
411 |
+
in_channels=features[1],
|
412 |
+
out_channels=features[1],
|
413 |
+
kernel_size=2,
|
414 |
+
stride=2,
|
415 |
+
padding=0,
|
416 |
+
bias=True,
|
417 |
+
dilation=1,
|
418 |
+
groups=1,
|
419 |
+
),
|
420 |
+
)
|
421 |
+
else:
|
422 |
+
pretrained.act_postprocess1 = nn.Sequential(
|
423 |
+
nn.Identity(), nn.Identity(), nn.Identity()
|
424 |
+
)
|
425 |
+
pretrained.act_postprocess2 = nn.Sequential(
|
426 |
+
nn.Identity(), nn.Identity(), nn.Identity()
|
427 |
+
)
|
428 |
+
|
429 |
+
pretrained.act_postprocess3 = nn.Sequential(
|
430 |
+
readout_oper[2],
|
431 |
+
Transpose(1, 2),
|
432 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
433 |
+
nn.Conv2d(
|
434 |
+
in_channels=vit_features,
|
435 |
+
out_channels=features[2],
|
436 |
+
kernel_size=1,
|
437 |
+
stride=1,
|
438 |
+
padding=0,
|
439 |
+
),
|
440 |
+
)
|
441 |
+
|
442 |
+
pretrained.act_postprocess4 = nn.Sequential(
|
443 |
+
readout_oper[3],
|
444 |
+
Transpose(1, 2),
|
445 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
446 |
+
nn.Conv2d(
|
447 |
+
in_channels=vit_features,
|
448 |
+
out_channels=features[3],
|
449 |
+
kernel_size=1,
|
450 |
+
stride=1,
|
451 |
+
padding=0,
|
452 |
+
),
|
453 |
+
nn.Conv2d(
|
454 |
+
in_channels=features[3],
|
455 |
+
out_channels=features[3],
|
456 |
+
kernel_size=3,
|
457 |
+
stride=2,
|
458 |
+
padding=1,
|
459 |
+
),
|
460 |
+
)
|
461 |
+
|
462 |
+
pretrained.model.start_index = start_index
|
463 |
+
pretrained.model.patch_size = [16, 16]
|
464 |
+
|
465 |
+
# We inject this function into the VisionTransformer instances so that
|
466 |
+
# we can use it with interpolated position embeddings without modifying the library source.
|
467 |
+
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
|
468 |
+
|
469 |
+
# We inject this function into the VisionTransformer instances so that
|
470 |
+
# we can use it with interpolated position embeddings without modifying the library source.
|
471 |
+
pretrained.model._resize_pos_embed = types.MethodType(
|
472 |
+
_resize_pos_embed, pretrained.model
|
473 |
+
)
|
474 |
+
|
475 |
+
return pretrained
|
476 |
+
|
477 |
+
|
478 |
+
def _make_pretrained_vitb_rn50_384(
|
479 |
+
pretrained, use_readout="ignore", hooks=None, use_vit_only=False
|
480 |
+
):
|
481 |
+
model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
|
482 |
+
|
483 |
+
hooks = [0, 1, 8, 11] if hooks == None else hooks
|
484 |
+
return _make_vit_b_rn50_backbone(
|
485 |
+
model,
|
486 |
+
features=[256, 512, 768, 768],
|
487 |
+
size=[384, 384],
|
488 |
+
hooks=hooks,
|
489 |
+
use_vit_only=use_vit_only,
|
490 |
+
use_readout=use_readout,
|
491 |
+
)
|
annotator/midas/utils.py
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Utils for monoDepth."""
|
2 |
+
import sys
|
3 |
+
import re
|
4 |
+
import numpy as np
|
5 |
+
import cv2
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
def read_pfm(path):
|
10 |
+
"""Read pfm file.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
path (str): path to file
|
14 |
+
|
15 |
+
Returns:
|
16 |
+
tuple: (data, scale)
|
17 |
+
"""
|
18 |
+
with open(path, "rb") as file:
|
19 |
+
|
20 |
+
color = None
|
21 |
+
width = None
|
22 |
+
height = None
|
23 |
+
scale = None
|
24 |
+
endian = None
|
25 |
+
|
26 |
+
header = file.readline().rstrip()
|
27 |
+
if header.decode("ascii") == "PF":
|
28 |
+
color = True
|
29 |
+
elif header.decode("ascii") == "Pf":
|
30 |
+
color = False
|
31 |
+
else:
|
32 |
+
raise Exception("Not a PFM file: " + path)
|
33 |
+
|
34 |
+
dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
|
35 |
+
if dim_match:
|
36 |
+
width, height = list(map(int, dim_match.groups()))
|
37 |
+
else:
|
38 |
+
raise Exception("Malformed PFM header.")
|
39 |
+
|
40 |
+
scale = float(file.readline().decode("ascii").rstrip())
|
41 |
+
if scale < 0:
|
42 |
+
# little-endian
|
43 |
+
endian = "<"
|
44 |
+
scale = -scale
|
45 |
+
else:
|
46 |
+
# big-endian
|
47 |
+
endian = ">"
|
48 |
+
|
49 |
+
data = np.fromfile(file, endian + "f")
|
50 |
+
shape = (height, width, 3) if color else (height, width)
|
51 |
+
|
52 |
+
data = np.reshape(data, shape)
|
53 |
+
data = np.flipud(data)
|
54 |
+
|
55 |
+
return data, scale
|
56 |
+
|
57 |
+
|
58 |
+
def write_pfm(path, image, scale=1):
|
59 |
+
"""Write pfm file.
|
60 |
+
|
61 |
+
Args:
|
62 |
+
path (str): pathto file
|
63 |
+
image (array): data
|
64 |
+
scale (int, optional): Scale. Defaults to 1.
|
65 |
+
"""
|
66 |
+
|
67 |
+
with open(path, "wb") as file:
|
68 |
+
color = None
|
69 |
+
|
70 |
+
if image.dtype.name != "float32":
|
71 |
+
raise Exception("Image dtype must be float32.")
|
72 |
+
|
73 |
+
image = np.flipud(image)
|
74 |
+
|
75 |
+
if len(image.shape) == 3 and image.shape[2] == 3: # color image
|
76 |
+
color = True
|
77 |
+
elif (
|
78 |
+
len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
|
79 |
+
): # greyscale
|
80 |
+
color = False
|
81 |
+
else:
|
82 |
+
raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
|
83 |
+
|
84 |
+
file.write("PF\n" if color else "Pf\n".encode())
|
85 |
+
file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
|
86 |
+
|
87 |
+
endian = image.dtype.byteorder
|
88 |
+
|
89 |
+
if endian == "<" or endian == "=" and sys.byteorder == "little":
|
90 |
+
scale = -scale
|
91 |
+
|
92 |
+
file.write("%f\n".encode() % scale)
|
93 |
+
|
94 |
+
image.tofile(file)
|
95 |
+
|
96 |
+
|
97 |
+
def read_image(path):
|
98 |
+
"""Read image and output RGB image (0-1).
|
99 |
+
|
100 |
+
Args:
|
101 |
+
path (str): path to file
|
102 |
+
|
103 |
+
Returns:
|
104 |
+
array: RGB image (0-1)
|
105 |
+
"""
|
106 |
+
img = cv2.imread(path)
|
107 |
+
|
108 |
+
if img.ndim == 2:
|
109 |
+
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
110 |
+
|
111 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
|
112 |
+
|
113 |
+
return img
|
114 |
+
|
115 |
+
|
116 |
+
def resize_image(img):
|
117 |
+
"""Resize image and make it fit for network.
|
118 |
+
|
119 |
+
Args:
|
120 |
+
img (array): image
|
121 |
+
|
122 |
+
Returns:
|
123 |
+
tensor: data ready for network
|
124 |
+
"""
|
125 |
+
height_orig = img.shape[0]
|
126 |
+
width_orig = img.shape[1]
|
127 |
+
|
128 |
+
if width_orig > height_orig:
|
129 |
+
scale = width_orig / 384
|
130 |
+
else:
|
131 |
+
scale = height_orig / 384
|
132 |
+
|
133 |
+
height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
|
134 |
+
width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
|
135 |
+
|
136 |
+
img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
|
137 |
+
|
138 |
+
img_resized = (
|
139 |
+
torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
|
140 |
+
)
|
141 |
+
img_resized = img_resized.unsqueeze(0)
|
142 |
+
|
143 |
+
return img_resized
|
144 |
+
|
145 |
+
|
146 |
+
def resize_depth(depth, width, height):
|
147 |
+
"""Resize depth map and bring to CPU (numpy).
|
148 |
+
|
149 |
+
Args:
|
150 |
+
depth (tensor): depth
|
151 |
+
width (int): image width
|
152 |
+
height (int): image height
|
153 |
+
|
154 |
+
Returns:
|
155 |
+
array: processed depth
|
156 |
+
"""
|
157 |
+
depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
|
158 |
+
|
159 |
+
depth_resized = cv2.resize(
|
160 |
+
depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
|
161 |
+
)
|
162 |
+
|
163 |
+
return depth_resized
|
164 |
+
|
165 |
+
def write_depth(path, depth, bits=1):
|
166 |
+
"""Write depth map to pfm and png file.
|
167 |
+
|
168 |
+
Args:
|
169 |
+
path (str): filepath without extension
|
170 |
+
depth (array): depth
|
171 |
+
"""
|
172 |
+
write_pfm(path + ".pfm", depth.astype(np.float32))
|
173 |
+
|
174 |
+
depth_min = depth.min()
|
175 |
+
depth_max = depth.max()
|
176 |
+
|
177 |
+
max_val = (2**(8*bits))-1
|
178 |
+
|
179 |
+
if depth_max - depth_min > np.finfo("float").eps:
|
180 |
+
out = max_val * (depth - depth_min) / (depth_max - depth_min)
|
181 |
+
else:
|
182 |
+
out = np.zeros(depth.shape, dtype=depth.type)
|
183 |
+
|
184 |
+
if bits == 1:
|
185 |
+
cv2.imwrite(path + ".png", out.astype("uint8"))
|
186 |
+
elif bits == 2:
|
187 |
+
cv2.imwrite(path + ".png", out.astype("uint16"))
|
188 |
+
|
189 |
+
return
|
annotator/mlsd/LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "{}"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright 2021-present NAVER Corp.
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
annotator/mlsd/__init__.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# MLSD Line Detection
|
2 |
+
# From https://github.com/navervision/mlsd
|
3 |
+
# Apache-2.0 license
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import os
|
9 |
+
|
10 |
+
from einops import rearrange
|
11 |
+
from .models.mbv2_mlsd_tiny import MobileV2_MLSD_Tiny
|
12 |
+
from .models.mbv2_mlsd_large import MobileV2_MLSD_Large
|
13 |
+
from .utils import pred_lines
|
14 |
+
|
15 |
+
from annotator.util import annotator_ckpts_path
|
16 |
+
|
17 |
+
|
18 |
+
remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/mlsd_large_512_fp32.pth"
|
19 |
+
|
20 |
+
|
21 |
+
class MLSDdetector:
|
22 |
+
def __init__(self):
|
23 |
+
model_path = os.path.join(annotator_ckpts_path, "mlsd_large_512_fp32.pth")
|
24 |
+
if not os.path.exists(model_path):
|
25 |
+
from basicsr.utils.download_util import load_file_from_url
|
26 |
+
load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
|
27 |
+
model = MobileV2_MLSD_Large()
|
28 |
+
model.load_state_dict(torch.load(model_path), strict=True)
|
29 |
+
self.model = model.cuda().eval()
|
30 |
+
|
31 |
+
def __call__(self, input_image, thr_v, thr_d):
|
32 |
+
assert input_image.ndim == 3
|
33 |
+
img = input_image
|
34 |
+
img_output = np.zeros_like(img)
|
35 |
+
try:
|
36 |
+
with torch.no_grad():
|
37 |
+
lines = pred_lines(img, self.model, [img.shape[0], img.shape[1]], thr_v, thr_d)
|
38 |
+
for line in lines:
|
39 |
+
x_start, y_start, x_end, y_end = [int(val) for val in line]
|
40 |
+
cv2.line(img_output, (x_start, y_start), (x_end, y_end), [255, 255, 255], 1)
|
41 |
+
except Exception as e:
|
42 |
+
pass
|
43 |
+
return img_output[:, :, 0]
|
annotator/mlsd/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (1.92 kB). View file
|
|
annotator/mlsd/models/mbv2_mlsd_large.py
ADDED
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.utils.model_zoo as model_zoo
|
6 |
+
from torch.nn import functional as F
|
7 |
+
|
8 |
+
|
9 |
+
class BlockTypeA(nn.Module):
|
10 |
+
def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True):
|
11 |
+
super(BlockTypeA, self).__init__()
|
12 |
+
self.conv1 = nn.Sequential(
|
13 |
+
nn.Conv2d(in_c2, out_c2, kernel_size=1),
|
14 |
+
nn.BatchNorm2d(out_c2),
|
15 |
+
nn.ReLU(inplace=True)
|
16 |
+
)
|
17 |
+
self.conv2 = nn.Sequential(
|
18 |
+
nn.Conv2d(in_c1, out_c1, kernel_size=1),
|
19 |
+
nn.BatchNorm2d(out_c1),
|
20 |
+
nn.ReLU(inplace=True)
|
21 |
+
)
|
22 |
+
self.upscale = upscale
|
23 |
+
|
24 |
+
def forward(self, a, b):
|
25 |
+
b = self.conv1(b)
|
26 |
+
a = self.conv2(a)
|
27 |
+
if self.upscale:
|
28 |
+
b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True)
|
29 |
+
return torch.cat((a, b), dim=1)
|
30 |
+
|
31 |
+
|
32 |
+
class BlockTypeB(nn.Module):
|
33 |
+
def __init__(self, in_c, out_c):
|
34 |
+
super(BlockTypeB, self).__init__()
|
35 |
+
self.conv1 = nn.Sequential(
|
36 |
+
nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
|
37 |
+
nn.BatchNorm2d(in_c),
|
38 |
+
nn.ReLU()
|
39 |
+
)
|
40 |
+
self.conv2 = nn.Sequential(
|
41 |
+
nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
|
42 |
+
nn.BatchNorm2d(out_c),
|
43 |
+
nn.ReLU()
|
44 |
+
)
|
45 |
+
|
46 |
+
def forward(self, x):
|
47 |
+
x = self.conv1(x) + x
|
48 |
+
x = self.conv2(x)
|
49 |
+
return x
|
50 |
+
|
51 |
+
class BlockTypeC(nn.Module):
|
52 |
+
def __init__(self, in_c, out_c):
|
53 |
+
super(BlockTypeC, self).__init__()
|
54 |
+
self.conv1 = nn.Sequential(
|
55 |
+
nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5),
|
56 |
+
nn.BatchNorm2d(in_c),
|
57 |
+
nn.ReLU()
|
58 |
+
)
|
59 |
+
self.conv2 = nn.Sequential(
|
60 |
+
nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
|
61 |
+
nn.BatchNorm2d(in_c),
|
62 |
+
nn.ReLU()
|
63 |
+
)
|
64 |
+
self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1)
|
65 |
+
|
66 |
+
def forward(self, x):
|
67 |
+
x = self.conv1(x)
|
68 |
+
x = self.conv2(x)
|
69 |
+
x = self.conv3(x)
|
70 |
+
return x
|
71 |
+
|
72 |
+
def _make_divisible(v, divisor, min_value=None):
|
73 |
+
"""
|
74 |
+
This function is taken from the original tf repo.
|
75 |
+
It ensures that all layers have a channel number that is divisible by 8
|
76 |
+
It can be seen here:
|
77 |
+
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
|
78 |
+
:param v:
|
79 |
+
:param divisor:
|
80 |
+
:param min_value:
|
81 |
+
:return:
|
82 |
+
"""
|
83 |
+
if min_value is None:
|
84 |
+
min_value = divisor
|
85 |
+
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
86 |
+
# Make sure that round down does not go down by more than 10%.
|
87 |
+
if new_v < 0.9 * v:
|
88 |
+
new_v += divisor
|
89 |
+
return new_v
|
90 |
+
|
91 |
+
|
92 |
+
class ConvBNReLU(nn.Sequential):
|
93 |
+
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
|
94 |
+
self.channel_pad = out_planes - in_planes
|
95 |
+
self.stride = stride
|
96 |
+
#padding = (kernel_size - 1) // 2
|
97 |
+
|
98 |
+
# TFLite uses slightly different padding than PyTorch
|
99 |
+
if stride == 2:
|
100 |
+
padding = 0
|
101 |
+
else:
|
102 |
+
padding = (kernel_size - 1) // 2
|
103 |
+
|
104 |
+
super(ConvBNReLU, self).__init__(
|
105 |
+
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
|
106 |
+
nn.BatchNorm2d(out_planes),
|
107 |
+
nn.ReLU6(inplace=True)
|
108 |
+
)
|
109 |
+
self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
|
110 |
+
|
111 |
+
|
112 |
+
def forward(self, x):
|
113 |
+
# TFLite uses different padding
|
114 |
+
if self.stride == 2:
|
115 |
+
x = F.pad(x, (0, 1, 0, 1), "constant", 0)
|
116 |
+
#print(x.shape)
|
117 |
+
|
118 |
+
for module in self:
|
119 |
+
if not isinstance(module, nn.MaxPool2d):
|
120 |
+
x = module(x)
|
121 |
+
return x
|
122 |
+
|
123 |
+
|
124 |
+
class InvertedResidual(nn.Module):
|
125 |
+
def __init__(self, inp, oup, stride, expand_ratio):
|
126 |
+
super(InvertedResidual, self).__init__()
|
127 |
+
self.stride = stride
|
128 |
+
assert stride in [1, 2]
|
129 |
+
|
130 |
+
hidden_dim = int(round(inp * expand_ratio))
|
131 |
+
self.use_res_connect = self.stride == 1 and inp == oup
|
132 |
+
|
133 |
+
layers = []
|
134 |
+
if expand_ratio != 1:
|
135 |
+
# pw
|
136 |
+
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
|
137 |
+
layers.extend([
|
138 |
+
# dw
|
139 |
+
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
|
140 |
+
# pw-linear
|
141 |
+
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
142 |
+
nn.BatchNorm2d(oup),
|
143 |
+
])
|
144 |
+
self.conv = nn.Sequential(*layers)
|
145 |
+
|
146 |
+
def forward(self, x):
|
147 |
+
if self.use_res_connect:
|
148 |
+
return x + self.conv(x)
|
149 |
+
else:
|
150 |
+
return self.conv(x)
|
151 |
+
|
152 |
+
|
153 |
+
class MobileNetV2(nn.Module):
|
154 |
+
def __init__(self, pretrained=True):
|
155 |
+
"""
|
156 |
+
MobileNet V2 main class
|
157 |
+
Args:
|
158 |
+
num_classes (int): Number of classes
|
159 |
+
width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
|
160 |
+
inverted_residual_setting: Network structure
|
161 |
+
round_nearest (int): Round the number of channels in each layer to be a multiple of this number
|
162 |
+
Set to 1 to turn off rounding
|
163 |
+
block: Module specifying inverted residual building block for mobilenet
|
164 |
+
"""
|
165 |
+
super(MobileNetV2, self).__init__()
|
166 |
+
|
167 |
+
block = InvertedResidual
|
168 |
+
input_channel = 32
|
169 |
+
last_channel = 1280
|
170 |
+
width_mult = 1.0
|
171 |
+
round_nearest = 8
|
172 |
+
|
173 |
+
inverted_residual_setting = [
|
174 |
+
# t, c, n, s
|
175 |
+
[1, 16, 1, 1],
|
176 |
+
[6, 24, 2, 2],
|
177 |
+
[6, 32, 3, 2],
|
178 |
+
[6, 64, 4, 2],
|
179 |
+
[6, 96, 3, 1],
|
180 |
+
#[6, 160, 3, 2],
|
181 |
+
#[6, 320, 1, 1],
|
182 |
+
]
|
183 |
+
|
184 |
+
# only check the first element, assuming user knows t,c,n,s are required
|
185 |
+
if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
|
186 |
+
raise ValueError("inverted_residual_setting should be non-empty "
|
187 |
+
"or a 4-element list, got {}".format(inverted_residual_setting))
|
188 |
+
|
189 |
+
# building first layer
|
190 |
+
input_channel = _make_divisible(input_channel * width_mult, round_nearest)
|
191 |
+
self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
|
192 |
+
features = [ConvBNReLU(4, input_channel, stride=2)]
|
193 |
+
# building inverted residual blocks
|
194 |
+
for t, c, n, s in inverted_residual_setting:
|
195 |
+
output_channel = _make_divisible(c * width_mult, round_nearest)
|
196 |
+
for i in range(n):
|
197 |
+
stride = s if i == 0 else 1
|
198 |
+
features.append(block(input_channel, output_channel, stride, expand_ratio=t))
|
199 |
+
input_channel = output_channel
|
200 |
+
|
201 |
+
self.features = nn.Sequential(*features)
|
202 |
+
self.fpn_selected = [1, 3, 6, 10, 13]
|
203 |
+
# weight initialization
|
204 |
+
for m in self.modules():
|
205 |
+
if isinstance(m, nn.Conv2d):
|
206 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out')
|
207 |
+
if m.bias is not None:
|
208 |
+
nn.init.zeros_(m.bias)
|
209 |
+
elif isinstance(m, nn.BatchNorm2d):
|
210 |
+
nn.init.ones_(m.weight)
|
211 |
+
nn.init.zeros_(m.bias)
|
212 |
+
elif isinstance(m, nn.Linear):
|
213 |
+
nn.init.normal_(m.weight, 0, 0.01)
|
214 |
+
nn.init.zeros_(m.bias)
|
215 |
+
if pretrained:
|
216 |
+
self._load_pretrained_model()
|
217 |
+
|
218 |
+
def _forward_impl(self, x):
|
219 |
+
# This exists since TorchScript doesn't support inheritance, so the superclass method
|
220 |
+
# (this one) needs to have a name other than `forward` that can be accessed in a subclass
|
221 |
+
fpn_features = []
|
222 |
+
for i, f in enumerate(self.features):
|
223 |
+
if i > self.fpn_selected[-1]:
|
224 |
+
break
|
225 |
+
x = f(x)
|
226 |
+
if i in self.fpn_selected:
|
227 |
+
fpn_features.append(x)
|
228 |
+
|
229 |
+
c1, c2, c3, c4, c5 = fpn_features
|
230 |
+
return c1, c2, c3, c4, c5
|
231 |
+
|
232 |
+
|
233 |
+
def forward(self, x):
|
234 |
+
return self._forward_impl(x)
|
235 |
+
|
236 |
+
def _load_pretrained_model(self):
|
237 |
+
pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth')
|
238 |
+
model_dict = {}
|
239 |
+
state_dict = self.state_dict()
|
240 |
+
for k, v in pretrain_dict.items():
|
241 |
+
if k in state_dict:
|
242 |
+
model_dict[k] = v
|
243 |
+
state_dict.update(model_dict)
|
244 |
+
self.load_state_dict(state_dict)
|
245 |
+
|
246 |
+
|
247 |
+
class MobileV2_MLSD_Large(nn.Module):
|
248 |
+
def __init__(self):
|
249 |
+
super(MobileV2_MLSD_Large, self).__init__()
|
250 |
+
|
251 |
+
self.backbone = MobileNetV2(pretrained=False)
|
252 |
+
## A, B
|
253 |
+
self.block15 = BlockTypeA(in_c1= 64, in_c2= 96,
|
254 |
+
out_c1= 64, out_c2=64,
|
255 |
+
upscale=False)
|
256 |
+
self.block16 = BlockTypeB(128, 64)
|
257 |
+
|
258 |
+
## A, B
|
259 |
+
self.block17 = BlockTypeA(in_c1 = 32, in_c2 = 64,
|
260 |
+
out_c1= 64, out_c2= 64)
|
261 |
+
self.block18 = BlockTypeB(128, 64)
|
262 |
+
|
263 |
+
## A, B
|
264 |
+
self.block19 = BlockTypeA(in_c1=24, in_c2=64,
|
265 |
+
out_c1=64, out_c2=64)
|
266 |
+
self.block20 = BlockTypeB(128, 64)
|
267 |
+
|
268 |
+
## A, B, C
|
269 |
+
self.block21 = BlockTypeA(in_c1=16, in_c2=64,
|
270 |
+
out_c1=64, out_c2=64)
|
271 |
+
self.block22 = BlockTypeB(128, 64)
|
272 |
+
|
273 |
+
self.block23 = BlockTypeC(64, 16)
|
274 |
+
|
275 |
+
def forward(self, x):
|
276 |
+
c1, c2, c3, c4, c5 = self.backbone(x)
|
277 |
+
|
278 |
+
x = self.block15(c4, c5)
|
279 |
+
x = self.block16(x)
|
280 |
+
|
281 |
+
x = self.block17(c3, x)
|
282 |
+
x = self.block18(x)
|
283 |
+
|
284 |
+
x = self.block19(c2, x)
|
285 |
+
x = self.block20(x)
|
286 |
+
|
287 |
+
x = self.block21(c1, x)
|
288 |
+
x = self.block22(x)
|
289 |
+
x = self.block23(x)
|
290 |
+
x = x[:, 7:, :, :]
|
291 |
+
|
292 |
+
return x
|
annotator/mlsd/models/mbv2_mlsd_tiny.py
ADDED
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.utils.model_zoo as model_zoo
|
6 |
+
from torch.nn import functional as F
|
7 |
+
|
8 |
+
|
9 |
+
class BlockTypeA(nn.Module):
|
10 |
+
def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True):
|
11 |
+
super(BlockTypeA, self).__init__()
|
12 |
+
self.conv1 = nn.Sequential(
|
13 |
+
nn.Conv2d(in_c2, out_c2, kernel_size=1),
|
14 |
+
nn.BatchNorm2d(out_c2),
|
15 |
+
nn.ReLU(inplace=True)
|
16 |
+
)
|
17 |
+
self.conv2 = nn.Sequential(
|
18 |
+
nn.Conv2d(in_c1, out_c1, kernel_size=1),
|
19 |
+
nn.BatchNorm2d(out_c1),
|
20 |
+
nn.ReLU(inplace=True)
|
21 |
+
)
|
22 |
+
self.upscale = upscale
|
23 |
+
|
24 |
+
def forward(self, a, b):
|
25 |
+
b = self.conv1(b)
|
26 |
+
a = self.conv2(a)
|
27 |
+
b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True)
|
28 |
+
return torch.cat((a, b), dim=1)
|
29 |
+
|
30 |
+
|
31 |
+
class BlockTypeB(nn.Module):
|
32 |
+
def __init__(self, in_c, out_c):
|
33 |
+
super(BlockTypeB, self).__init__()
|
34 |
+
self.conv1 = nn.Sequential(
|
35 |
+
nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
|
36 |
+
nn.BatchNorm2d(in_c),
|
37 |
+
nn.ReLU()
|
38 |
+
)
|
39 |
+
self.conv2 = nn.Sequential(
|
40 |
+
nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
|
41 |
+
nn.BatchNorm2d(out_c),
|
42 |
+
nn.ReLU()
|
43 |
+
)
|
44 |
+
|
45 |
+
def forward(self, x):
|
46 |
+
x = self.conv1(x) + x
|
47 |
+
x = self.conv2(x)
|
48 |
+
return x
|
49 |
+
|
50 |
+
class BlockTypeC(nn.Module):
|
51 |
+
def __init__(self, in_c, out_c):
|
52 |
+
super(BlockTypeC, self).__init__()
|
53 |
+
self.conv1 = nn.Sequential(
|
54 |
+
nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5),
|
55 |
+
nn.BatchNorm2d(in_c),
|
56 |
+
nn.ReLU()
|
57 |
+
)
|
58 |
+
self.conv2 = nn.Sequential(
|
59 |
+
nn.Conv2d(in_c, in_c, kernel_size=3, padding=1),
|
60 |
+
nn.BatchNorm2d(in_c),
|
61 |
+
nn.ReLU()
|
62 |
+
)
|
63 |
+
self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1)
|
64 |
+
|
65 |
+
def forward(self, x):
|
66 |
+
x = self.conv1(x)
|
67 |
+
x = self.conv2(x)
|
68 |
+
x = self.conv3(x)
|
69 |
+
return x
|
70 |
+
|
71 |
+
def _make_divisible(v, divisor, min_value=None):
|
72 |
+
"""
|
73 |
+
This function is taken from the original tf repo.
|
74 |
+
It ensures that all layers have a channel number that is divisible by 8
|
75 |
+
It can be seen here:
|
76 |
+
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
|
77 |
+
:param v:
|
78 |
+
:param divisor:
|
79 |
+
:param min_value:
|
80 |
+
:return:
|
81 |
+
"""
|
82 |
+
if min_value is None:
|
83 |
+
min_value = divisor
|
84 |
+
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
85 |
+
# Make sure that round down does not go down by more than 10%.
|
86 |
+
if new_v < 0.9 * v:
|
87 |
+
new_v += divisor
|
88 |
+
return new_v
|
89 |
+
|
90 |
+
|
91 |
+
class ConvBNReLU(nn.Sequential):
|
92 |
+
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
|
93 |
+
self.channel_pad = out_planes - in_planes
|
94 |
+
self.stride = stride
|
95 |
+
#padding = (kernel_size - 1) // 2
|
96 |
+
|
97 |
+
# TFLite uses slightly different padding than PyTorch
|
98 |
+
if stride == 2:
|
99 |
+
padding = 0
|
100 |
+
else:
|
101 |
+
padding = (kernel_size - 1) // 2
|
102 |
+
|
103 |
+
super(ConvBNReLU, self).__init__(
|
104 |
+
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
|
105 |
+
nn.BatchNorm2d(out_planes),
|
106 |
+
nn.ReLU6(inplace=True)
|
107 |
+
)
|
108 |
+
self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
|
109 |
+
|
110 |
+
|
111 |
+
def forward(self, x):
|
112 |
+
# TFLite uses different padding
|
113 |
+
if self.stride == 2:
|
114 |
+
x = F.pad(x, (0, 1, 0, 1), "constant", 0)
|
115 |
+
#print(x.shape)
|
116 |
+
|
117 |
+
for module in self:
|
118 |
+
if not isinstance(module, nn.MaxPool2d):
|
119 |
+
x = module(x)
|
120 |
+
return x
|
121 |
+
|
122 |
+
|
123 |
+
class InvertedResidual(nn.Module):
|
124 |
+
def __init__(self, inp, oup, stride, expand_ratio):
|
125 |
+
super(InvertedResidual, self).__init__()
|
126 |
+
self.stride = stride
|
127 |
+
assert stride in [1, 2]
|
128 |
+
|
129 |
+
hidden_dim = int(round(inp * expand_ratio))
|
130 |
+
self.use_res_connect = self.stride == 1 and inp == oup
|
131 |
+
|
132 |
+
layers = []
|
133 |
+
if expand_ratio != 1:
|
134 |
+
# pw
|
135 |
+
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
|
136 |
+
layers.extend([
|
137 |
+
# dw
|
138 |
+
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
|
139 |
+
# pw-linear
|
140 |
+
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
|
141 |
+
nn.BatchNorm2d(oup),
|
142 |
+
])
|
143 |
+
self.conv = nn.Sequential(*layers)
|
144 |
+
|
145 |
+
def forward(self, x):
|
146 |
+
if self.use_res_connect:
|
147 |
+
return x + self.conv(x)
|
148 |
+
else:
|
149 |
+
return self.conv(x)
|
150 |
+
|
151 |
+
|
152 |
+
class MobileNetV2(nn.Module):
|
153 |
+
def __init__(self, pretrained=True):
|
154 |
+
"""
|
155 |
+
MobileNet V2 main class
|
156 |
+
Args:
|
157 |
+
num_classes (int): Number of classes
|
158 |
+
width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
|
159 |
+
inverted_residual_setting: Network structure
|
160 |
+
round_nearest (int): Round the number of channels in each layer to be a multiple of this number
|
161 |
+
Set to 1 to turn off rounding
|
162 |
+
block: Module specifying inverted residual building block for mobilenet
|
163 |
+
"""
|
164 |
+
super(MobileNetV2, self).__init__()
|
165 |
+
|
166 |
+
block = InvertedResidual
|
167 |
+
input_channel = 32
|
168 |
+
last_channel = 1280
|
169 |
+
width_mult = 1.0
|
170 |
+
round_nearest = 8
|
171 |
+
|
172 |
+
inverted_residual_setting = [
|
173 |
+
# t, c, n, s
|
174 |
+
[1, 16, 1, 1],
|
175 |
+
[6, 24, 2, 2],
|
176 |
+
[6, 32, 3, 2],
|
177 |
+
[6, 64, 4, 2],
|
178 |
+
#[6, 96, 3, 1],
|
179 |
+
#[6, 160, 3, 2],
|
180 |
+
#[6, 320, 1, 1],
|
181 |
+
]
|
182 |
+
|
183 |
+
# only check the first element, assuming user knows t,c,n,s are required
|
184 |
+
if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
|
185 |
+
raise ValueError("inverted_residual_setting should be non-empty "
|
186 |
+
"or a 4-element list, got {}".format(inverted_residual_setting))
|
187 |
+
|
188 |
+
# building first layer
|
189 |
+
input_channel = _make_divisible(input_channel * width_mult, round_nearest)
|
190 |
+
self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
|
191 |
+
features = [ConvBNReLU(4, input_channel, stride=2)]
|
192 |
+
# building inverted residual blocks
|
193 |
+
for t, c, n, s in inverted_residual_setting:
|
194 |
+
output_channel = _make_divisible(c * width_mult, round_nearest)
|
195 |
+
for i in range(n):
|
196 |
+
stride = s if i == 0 else 1
|
197 |
+
features.append(block(input_channel, output_channel, stride, expand_ratio=t))
|
198 |
+
input_channel = output_channel
|
199 |
+
self.features = nn.Sequential(*features)
|
200 |
+
|
201 |
+
self.fpn_selected = [3, 6, 10]
|
202 |
+
# weight initialization
|
203 |
+
for m in self.modules():
|
204 |
+
if isinstance(m, nn.Conv2d):
|
205 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out')
|
206 |
+
if m.bias is not None:
|
207 |
+
nn.init.zeros_(m.bias)
|
208 |
+
elif isinstance(m, nn.BatchNorm2d):
|
209 |
+
nn.init.ones_(m.weight)
|
210 |
+
nn.init.zeros_(m.bias)
|
211 |
+
elif isinstance(m, nn.Linear):
|
212 |
+
nn.init.normal_(m.weight, 0, 0.01)
|
213 |
+
nn.init.zeros_(m.bias)
|
214 |
+
|
215 |
+
#if pretrained:
|
216 |
+
# self._load_pretrained_model()
|
217 |
+
|
218 |
+
def _forward_impl(self, x):
|
219 |
+
# This exists since TorchScript doesn't support inheritance, so the superclass method
|
220 |
+
# (this one) needs to have a name other than `forward` that can be accessed in a subclass
|
221 |
+
fpn_features = []
|
222 |
+
for i, f in enumerate(self.features):
|
223 |
+
if i > self.fpn_selected[-1]:
|
224 |
+
break
|
225 |
+
x = f(x)
|
226 |
+
if i in self.fpn_selected:
|
227 |
+
fpn_features.append(x)
|
228 |
+
|
229 |
+
c2, c3, c4 = fpn_features
|
230 |
+
return c2, c3, c4
|
231 |
+
|
232 |
+
|
233 |
+
def forward(self, x):
|
234 |
+
return self._forward_impl(x)
|
235 |
+
|
236 |
+
def _load_pretrained_model(self):
|
237 |
+
pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth')
|
238 |
+
model_dict = {}
|
239 |
+
state_dict = self.state_dict()
|
240 |
+
for k, v in pretrain_dict.items():
|
241 |
+
if k in state_dict:
|
242 |
+
model_dict[k] = v
|
243 |
+
state_dict.update(model_dict)
|
244 |
+
self.load_state_dict(state_dict)
|
245 |
+
|
246 |
+
|
247 |
+
class MobileV2_MLSD_Tiny(nn.Module):
|
248 |
+
def __init__(self):
|
249 |
+
super(MobileV2_MLSD_Tiny, self).__init__()
|
250 |
+
|
251 |
+
self.backbone = MobileNetV2(pretrained=True)
|
252 |
+
|
253 |
+
self.block12 = BlockTypeA(in_c1= 32, in_c2= 64,
|
254 |
+
out_c1= 64, out_c2=64)
|
255 |
+
self.block13 = BlockTypeB(128, 64)
|
256 |
+
|
257 |
+
self.block14 = BlockTypeA(in_c1 = 24, in_c2 = 64,
|
258 |
+
out_c1= 32, out_c2= 32)
|
259 |
+
self.block15 = BlockTypeB(64, 64)
|
260 |
+
|
261 |
+
self.block16 = BlockTypeC(64, 16)
|
262 |
+
|
263 |
+
def forward(self, x):
|
264 |
+
c2, c3, c4 = self.backbone(x)
|
265 |
+
|
266 |
+
x = self.block12(c3, c4)
|
267 |
+
x = self.block13(x)
|
268 |
+
x = self.block14(c2, x)
|
269 |
+
x = self.block15(x)
|
270 |
+
x = self.block16(x)
|
271 |
+
x = x[:, 7:, :, :]
|
272 |
+
#print(x.shape)
|
273 |
+
x = F.interpolate(x, scale_factor=2.0, mode='bilinear', align_corners=True)
|
274 |
+
|
275 |
+
return x
|
annotator/mlsd/utils.py
ADDED
@@ -0,0 +1,580 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
modified by lihaoweicv
|
3 |
+
pytorch version
|
4 |
+
'''
|
5 |
+
|
6 |
+
'''
|
7 |
+
M-LSD
|
8 |
+
Copyright 2021-present NAVER Corp.
|
9 |
+
Apache License v2.0
|
10 |
+
'''
|
11 |
+
|
12 |
+
import os
|
13 |
+
import numpy as np
|
14 |
+
import cv2
|
15 |
+
import torch
|
16 |
+
from torch.nn import functional as F
|
17 |
+
|
18 |
+
|
19 |
+
def deccode_output_score_and_ptss(tpMap, topk_n = 200, ksize = 5):
|
20 |
+
'''
|
21 |
+
tpMap:
|
22 |
+
center: tpMap[1, 0, :, :]
|
23 |
+
displacement: tpMap[1, 1:5, :, :]
|
24 |
+
'''
|
25 |
+
b, c, h, w = tpMap.shape
|
26 |
+
assert b==1, 'only support bsize==1'
|
27 |
+
displacement = tpMap[:, 1:5, :, :][0]
|
28 |
+
center = tpMap[:, 0, :, :]
|
29 |
+
heat = torch.sigmoid(center)
|
30 |
+
hmax = F.max_pool2d( heat, (ksize, ksize), stride=1, padding=(ksize-1)//2)
|
31 |
+
keep = (hmax == heat).float()
|
32 |
+
heat = heat * keep
|
33 |
+
heat = heat.reshape(-1, )
|
34 |
+
|
35 |
+
scores, indices = torch.topk(heat, topk_n, dim=-1, largest=True)
|
36 |
+
yy = torch.floor_divide(indices, w).unsqueeze(-1)
|
37 |
+
xx = torch.fmod(indices, w).unsqueeze(-1)
|
38 |
+
ptss = torch.cat((yy, xx),dim=-1)
|
39 |
+
|
40 |
+
ptss = ptss.detach().cpu().numpy()
|
41 |
+
scores = scores.detach().cpu().numpy()
|
42 |
+
displacement = displacement.detach().cpu().numpy()
|
43 |
+
displacement = displacement.transpose((1,2,0))
|
44 |
+
return ptss, scores, displacement
|
45 |
+
|
46 |
+
|
47 |
+
def pred_lines(image, model,
|
48 |
+
input_shape=[512, 512],
|
49 |
+
score_thr=0.10,
|
50 |
+
dist_thr=20.0):
|
51 |
+
h, w, _ = image.shape
|
52 |
+
h_ratio, w_ratio = [h / input_shape[0], w / input_shape[1]]
|
53 |
+
|
54 |
+
resized_image = np.concatenate([cv2.resize(image, (input_shape[1], input_shape[0]), interpolation=cv2.INTER_AREA),
|
55 |
+
np.ones([input_shape[0], input_shape[1], 1])], axis=-1)
|
56 |
+
|
57 |
+
resized_image = resized_image.transpose((2,0,1))
|
58 |
+
batch_image = np.expand_dims(resized_image, axis=0).astype('float32')
|
59 |
+
batch_image = (batch_image / 127.5) - 1.0
|
60 |
+
|
61 |
+
batch_image = torch.from_numpy(batch_image).float().cuda()
|
62 |
+
outputs = model(batch_image)
|
63 |
+
pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3)
|
64 |
+
start = vmap[:, :, :2]
|
65 |
+
end = vmap[:, :, 2:]
|
66 |
+
dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1))
|
67 |
+
|
68 |
+
segments_list = []
|
69 |
+
for center, score in zip(pts, pts_score):
|
70 |
+
y, x = center
|
71 |
+
distance = dist_map[y, x]
|
72 |
+
if score > score_thr and distance > dist_thr:
|
73 |
+
disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :]
|
74 |
+
x_start = x + disp_x_start
|
75 |
+
y_start = y + disp_y_start
|
76 |
+
x_end = x + disp_x_end
|
77 |
+
y_end = y + disp_y_end
|
78 |
+
segments_list.append([x_start, y_start, x_end, y_end])
|
79 |
+
|
80 |
+
lines = 2 * np.array(segments_list) # 256 > 512
|
81 |
+
lines[:, 0] = lines[:, 0] * w_ratio
|
82 |
+
lines[:, 1] = lines[:, 1] * h_ratio
|
83 |
+
lines[:, 2] = lines[:, 2] * w_ratio
|
84 |
+
lines[:, 3] = lines[:, 3] * h_ratio
|
85 |
+
|
86 |
+
return lines
|
87 |
+
|
88 |
+
|
89 |
+
def pred_squares(image,
|
90 |
+
model,
|
91 |
+
input_shape=[512, 512],
|
92 |
+
params={'score': 0.06,
|
93 |
+
'outside_ratio': 0.28,
|
94 |
+
'inside_ratio': 0.45,
|
95 |
+
'w_overlap': 0.0,
|
96 |
+
'w_degree': 1.95,
|
97 |
+
'w_length': 0.0,
|
98 |
+
'w_area': 1.86,
|
99 |
+
'w_center': 0.14}):
|
100 |
+
'''
|
101 |
+
shape = [height, width]
|
102 |
+
'''
|
103 |
+
h, w, _ = image.shape
|
104 |
+
original_shape = [h, w]
|
105 |
+
|
106 |
+
resized_image = np.concatenate([cv2.resize(image, (input_shape[0], input_shape[1]), interpolation=cv2.INTER_AREA),
|
107 |
+
np.ones([input_shape[0], input_shape[1], 1])], axis=-1)
|
108 |
+
resized_image = resized_image.transpose((2, 0, 1))
|
109 |
+
batch_image = np.expand_dims(resized_image, axis=0).astype('float32')
|
110 |
+
batch_image = (batch_image / 127.5) - 1.0
|
111 |
+
|
112 |
+
batch_image = torch.from_numpy(batch_image).float().cuda()
|
113 |
+
outputs = model(batch_image)
|
114 |
+
|
115 |
+
pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3)
|
116 |
+
start = vmap[:, :, :2] # (x, y)
|
117 |
+
end = vmap[:, :, 2:] # (x, y)
|
118 |
+
dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1))
|
119 |
+
|
120 |
+
junc_list = []
|
121 |
+
segments_list = []
|
122 |
+
for junc, score in zip(pts, pts_score):
|
123 |
+
y, x = junc
|
124 |
+
distance = dist_map[y, x]
|
125 |
+
if score > params['score'] and distance > 20.0:
|
126 |
+
junc_list.append([x, y])
|
127 |
+
disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :]
|
128 |
+
d_arrow = 1.0
|
129 |
+
x_start = x + d_arrow * disp_x_start
|
130 |
+
y_start = y + d_arrow * disp_y_start
|
131 |
+
x_end = x + d_arrow * disp_x_end
|
132 |
+
y_end = y + d_arrow * disp_y_end
|
133 |
+
segments_list.append([x_start, y_start, x_end, y_end])
|
134 |
+
|
135 |
+
segments = np.array(segments_list)
|
136 |
+
|
137 |
+
####### post processing for squares
|
138 |
+
# 1. get unique lines
|
139 |
+
point = np.array([[0, 0]])
|
140 |
+
point = point[0]
|
141 |
+
start = segments[:, :2]
|
142 |
+
end = segments[:, 2:]
|
143 |
+
diff = start - end
|
144 |
+
a = diff[:, 1]
|
145 |
+
b = -diff[:, 0]
|
146 |
+
c = a * start[:, 0] + b * start[:, 1]
|
147 |
+
|
148 |
+
d = np.abs(a * point[0] + b * point[1] - c) / np.sqrt(a ** 2 + b ** 2 + 1e-10)
|
149 |
+
theta = np.arctan2(diff[:, 0], diff[:, 1]) * 180 / np.pi
|
150 |
+
theta[theta < 0.0] += 180
|
151 |
+
hough = np.concatenate([d[:, None], theta[:, None]], axis=-1)
|
152 |
+
|
153 |
+
d_quant = 1
|
154 |
+
theta_quant = 2
|
155 |
+
hough[:, 0] //= d_quant
|
156 |
+
hough[:, 1] //= theta_quant
|
157 |
+
_, indices, counts = np.unique(hough, axis=0, return_index=True, return_counts=True)
|
158 |
+
|
159 |
+
acc_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='float32')
|
160 |
+
idx_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='int32') - 1
|
161 |
+
yx_indices = hough[indices, :].astype('int32')
|
162 |
+
acc_map[yx_indices[:, 0], yx_indices[:, 1]] = counts
|
163 |
+
idx_map[yx_indices[:, 0], yx_indices[:, 1]] = indices
|
164 |
+
|
165 |
+
acc_map_np = acc_map
|
166 |
+
# acc_map = acc_map[None, :, :, None]
|
167 |
+
#
|
168 |
+
# ### fast suppression using tensorflow op
|
169 |
+
# acc_map = tf.constant(acc_map, dtype=tf.float32)
|
170 |
+
# max_acc_map = tf.keras.layers.MaxPool2D(pool_size=(5, 5), strides=1, padding='same')(acc_map)
|
171 |
+
# acc_map = acc_map * tf.cast(tf.math.equal(acc_map, max_acc_map), tf.float32)
|
172 |
+
# flatten_acc_map = tf.reshape(acc_map, [1, -1])
|
173 |
+
# topk_values, topk_indices = tf.math.top_k(flatten_acc_map, k=len(pts))
|
174 |
+
# _, h, w, _ = acc_map.shape
|
175 |
+
# y = tf.expand_dims(topk_indices // w, axis=-1)
|
176 |
+
# x = tf.expand_dims(topk_indices % w, axis=-1)
|
177 |
+
# yx = tf.concat([y, x], axis=-1)
|
178 |
+
|
179 |
+
### fast suppression using pytorch op
|
180 |
+
acc_map = torch.from_numpy(acc_map_np).unsqueeze(0).unsqueeze(0)
|
181 |
+
_,_, h, w = acc_map.shape
|
182 |
+
max_acc_map = F.max_pool2d(acc_map,kernel_size=5, stride=1, padding=2)
|
183 |
+
acc_map = acc_map * ( (acc_map == max_acc_map).float() )
|
184 |
+
flatten_acc_map = acc_map.reshape([-1, ])
|
185 |
+
|
186 |
+
scores, indices = torch.topk(flatten_acc_map, len(pts), dim=-1, largest=True)
|
187 |
+
yy = torch.div(indices, w, rounding_mode='floor').unsqueeze(-1)
|
188 |
+
xx = torch.fmod(indices, w).unsqueeze(-1)
|
189 |
+
yx = torch.cat((yy, xx), dim=-1)
|
190 |
+
|
191 |
+
yx = yx.detach().cpu().numpy()
|
192 |
+
|
193 |
+
topk_values = scores.detach().cpu().numpy()
|
194 |
+
indices = idx_map[yx[:, 0], yx[:, 1]]
|
195 |
+
basis = 5 // 2
|
196 |
+
|
197 |
+
merged_segments = []
|
198 |
+
for yx_pt, max_indice, value in zip(yx, indices, topk_values):
|
199 |
+
y, x = yx_pt
|
200 |
+
if max_indice == -1 or value == 0:
|
201 |
+
continue
|
202 |
+
segment_list = []
|
203 |
+
for y_offset in range(-basis, basis + 1):
|
204 |
+
for x_offset in range(-basis, basis + 1):
|
205 |
+
indice = idx_map[y + y_offset, x + x_offset]
|
206 |
+
cnt = int(acc_map_np[y + y_offset, x + x_offset])
|
207 |
+
if indice != -1:
|
208 |
+
segment_list.append(segments[indice])
|
209 |
+
if cnt > 1:
|
210 |
+
check_cnt = 1
|
211 |
+
current_hough = hough[indice]
|
212 |
+
for new_indice, new_hough in enumerate(hough):
|
213 |
+
if (current_hough == new_hough).all() and indice != new_indice:
|
214 |
+
segment_list.append(segments[new_indice])
|
215 |
+
check_cnt += 1
|
216 |
+
if check_cnt == cnt:
|
217 |
+
break
|
218 |
+
group_segments = np.array(segment_list).reshape([-1, 2])
|
219 |
+
sorted_group_segments = np.sort(group_segments, axis=0)
|
220 |
+
x_min, y_min = sorted_group_segments[0, :]
|
221 |
+
x_max, y_max = sorted_group_segments[-1, :]
|
222 |
+
|
223 |
+
deg = theta[max_indice]
|
224 |
+
if deg >= 90:
|
225 |
+
merged_segments.append([x_min, y_max, x_max, y_min])
|
226 |
+
else:
|
227 |
+
merged_segments.append([x_min, y_min, x_max, y_max])
|
228 |
+
|
229 |
+
# 2. get intersections
|
230 |
+
new_segments = np.array(merged_segments) # (x1, y1, x2, y2)
|
231 |
+
start = new_segments[:, :2] # (x1, y1)
|
232 |
+
end = new_segments[:, 2:] # (x2, y2)
|
233 |
+
new_centers = (start + end) / 2.0
|
234 |
+
diff = start - end
|
235 |
+
dist_segments = np.sqrt(np.sum(diff ** 2, axis=-1))
|
236 |
+
|
237 |
+
# ax + by = c
|
238 |
+
a = diff[:, 1]
|
239 |
+
b = -diff[:, 0]
|
240 |
+
c = a * start[:, 0] + b * start[:, 1]
|
241 |
+
pre_det = a[:, None] * b[None, :]
|
242 |
+
det = pre_det - np.transpose(pre_det)
|
243 |
+
|
244 |
+
pre_inter_y = a[:, None] * c[None, :]
|
245 |
+
inter_y = (pre_inter_y - np.transpose(pre_inter_y)) / (det + 1e-10)
|
246 |
+
pre_inter_x = c[:, None] * b[None, :]
|
247 |
+
inter_x = (pre_inter_x - np.transpose(pre_inter_x)) / (det + 1e-10)
|
248 |
+
inter_pts = np.concatenate([inter_x[:, :, None], inter_y[:, :, None]], axis=-1).astype('int32')
|
249 |
+
|
250 |
+
# 3. get corner information
|
251 |
+
# 3.1 get distance
|
252 |
+
'''
|
253 |
+
dist_segments:
|
254 |
+
| dist(0), dist(1), dist(2), ...|
|
255 |
+
dist_inter_to_segment1:
|
256 |
+
| dist(inter,0), dist(inter,0), dist(inter,0), ... |
|
257 |
+
| dist(inter,1), dist(inter,1), dist(inter,1), ... |
|
258 |
+
...
|
259 |
+
dist_inter_to_semgnet2:
|
260 |
+
| dist(inter,0), dist(inter,1), dist(inter,2), ... |
|
261 |
+
| dist(inter,0), dist(inter,1), dist(inter,2), ... |
|
262 |
+
...
|
263 |
+
'''
|
264 |
+
|
265 |
+
dist_inter_to_segment1_start = np.sqrt(
|
266 |
+
np.sum(((inter_pts - start[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
|
267 |
+
dist_inter_to_segment1_end = np.sqrt(
|
268 |
+
np.sum(((inter_pts - end[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
|
269 |
+
dist_inter_to_segment2_start = np.sqrt(
|
270 |
+
np.sum(((inter_pts - start[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
|
271 |
+
dist_inter_to_segment2_end = np.sqrt(
|
272 |
+
np.sum(((inter_pts - end[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1]
|
273 |
+
|
274 |
+
# sort ascending
|
275 |
+
dist_inter_to_segment1 = np.sort(
|
276 |
+
np.concatenate([dist_inter_to_segment1_start, dist_inter_to_segment1_end], axis=-1),
|
277 |
+
axis=-1) # [n_batch, n_batch, 2]
|
278 |
+
dist_inter_to_segment2 = np.sort(
|
279 |
+
np.concatenate([dist_inter_to_segment2_start, dist_inter_to_segment2_end], axis=-1),
|
280 |
+
axis=-1) # [n_batch, n_batch, 2]
|
281 |
+
|
282 |
+
# 3.2 get degree
|
283 |
+
inter_to_start = new_centers[:, None, :] - inter_pts
|
284 |
+
deg_inter_to_start = np.arctan2(inter_to_start[:, :, 1], inter_to_start[:, :, 0]) * 180 / np.pi
|
285 |
+
deg_inter_to_start[deg_inter_to_start < 0.0] += 360
|
286 |
+
inter_to_end = new_centers[None, :, :] - inter_pts
|
287 |
+
deg_inter_to_end = np.arctan2(inter_to_end[:, :, 1], inter_to_end[:, :, 0]) * 180 / np.pi
|
288 |
+
deg_inter_to_end[deg_inter_to_end < 0.0] += 360
|
289 |
+
|
290 |
+
'''
|
291 |
+
B -- G
|
292 |
+
| |
|
293 |
+
C -- R
|
294 |
+
B : blue / G: green / C: cyan / R: red
|
295 |
+
|
296 |
+
0 -- 1
|
297 |
+
| |
|
298 |
+
3 -- 2
|
299 |
+
'''
|
300 |
+
# rename variables
|
301 |
+
deg1_map, deg2_map = deg_inter_to_start, deg_inter_to_end
|
302 |
+
# sort deg ascending
|
303 |
+
deg_sort = np.sort(np.concatenate([deg1_map[:, :, None], deg2_map[:, :, None]], axis=-1), axis=-1)
|
304 |
+
|
305 |
+
deg_diff_map = np.abs(deg1_map - deg2_map)
|
306 |
+
# we only consider the smallest degree of intersect
|
307 |
+
deg_diff_map[deg_diff_map > 180] = 360 - deg_diff_map[deg_diff_map > 180]
|
308 |
+
|
309 |
+
# define available degree range
|
310 |
+
deg_range = [60, 120]
|
311 |
+
|
312 |
+
corner_dict = {corner_info: [] for corner_info in range(4)}
|
313 |
+
inter_points = []
|
314 |
+
for i in range(inter_pts.shape[0]):
|
315 |
+
for j in range(i + 1, inter_pts.shape[1]):
|
316 |
+
# i, j > line index, always i < j
|
317 |
+
x, y = inter_pts[i, j, :]
|
318 |
+
deg1, deg2 = deg_sort[i, j, :]
|
319 |
+
deg_diff = deg_diff_map[i, j]
|
320 |
+
|
321 |
+
check_degree = deg_diff > deg_range[0] and deg_diff < deg_range[1]
|
322 |
+
|
323 |
+
outside_ratio = params['outside_ratio'] # over ratio >>> drop it!
|
324 |
+
inside_ratio = params['inside_ratio'] # over ratio >>> drop it!
|
325 |
+
check_distance = ((dist_inter_to_segment1[i, j, 1] >= dist_segments[i] and \
|
326 |
+
dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * outside_ratio) or \
|
327 |
+
(dist_inter_to_segment1[i, j, 1] <= dist_segments[i] and \
|
328 |
+
dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * inside_ratio)) and \
|
329 |
+
((dist_inter_to_segment2[i, j, 1] >= dist_segments[j] and \
|
330 |
+
dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * outside_ratio) or \
|
331 |
+
(dist_inter_to_segment2[i, j, 1] <= dist_segments[j] and \
|
332 |
+
dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * inside_ratio))
|
333 |
+
|
334 |
+
if check_degree and check_distance:
|
335 |
+
corner_info = None
|
336 |
+
|
337 |
+
if (deg1 >= 0 and deg1 <= 45 and deg2 >= 45 and deg2 <= 120) or \
|
338 |
+
(deg2 >= 315 and deg1 >= 45 and deg1 <= 120):
|
339 |
+
corner_info, color_info = 0, 'blue'
|
340 |
+
elif (deg1 >= 45 and deg1 <= 125 and deg2 >= 125 and deg2 <= 225):
|
341 |
+
corner_info, color_info = 1, 'green'
|
342 |
+
elif (deg1 >= 125 and deg1 <= 225 and deg2 >= 225 and deg2 <= 315):
|
343 |
+
corner_info, color_info = 2, 'black'
|
344 |
+
elif (deg1 >= 0 and deg1 <= 45 and deg2 >= 225 and deg2 <= 315) or \
|
345 |
+
(deg2 >= 315 and deg1 >= 225 and deg1 <= 315):
|
346 |
+
corner_info, color_info = 3, 'cyan'
|
347 |
+
else:
|
348 |
+
corner_info, color_info = 4, 'red' # we don't use it
|
349 |
+
continue
|
350 |
+
|
351 |
+
corner_dict[corner_info].append([x, y, i, j])
|
352 |
+
inter_points.append([x, y])
|
353 |
+
|
354 |
+
square_list = []
|
355 |
+
connect_list = []
|
356 |
+
segments_list = []
|
357 |
+
for corner0 in corner_dict[0]:
|
358 |
+
for corner1 in corner_dict[1]:
|
359 |
+
connect01 = False
|
360 |
+
for corner0_line in corner0[2:]:
|
361 |
+
if corner0_line in corner1[2:]:
|
362 |
+
connect01 = True
|
363 |
+
break
|
364 |
+
if connect01:
|
365 |
+
for corner2 in corner_dict[2]:
|
366 |
+
connect12 = False
|
367 |
+
for corner1_line in corner1[2:]:
|
368 |
+
if corner1_line in corner2[2:]:
|
369 |
+
connect12 = True
|
370 |
+
break
|
371 |
+
if connect12:
|
372 |
+
for corner3 in corner_dict[3]:
|
373 |
+
connect23 = False
|
374 |
+
for corner2_line in corner2[2:]:
|
375 |
+
if corner2_line in corner3[2:]:
|
376 |
+
connect23 = True
|
377 |
+
break
|
378 |
+
if connect23:
|
379 |
+
for corner3_line in corner3[2:]:
|
380 |
+
if corner3_line in corner0[2:]:
|
381 |
+
# SQUARE!!!
|
382 |
+
'''
|
383 |
+
0 -- 1
|
384 |
+
| |
|
385 |
+
3 -- 2
|
386 |
+
square_list:
|
387 |
+
order: 0 > 1 > 2 > 3
|
388 |
+
| x0, y0, x1, y1, x2, y2, x3, y3 |
|
389 |
+
| x0, y0, x1, y1, x2, y2, x3, y3 |
|
390 |
+
...
|
391 |
+
connect_list:
|
392 |
+
order: 01 > 12 > 23 > 30
|
393 |
+
| line_idx01, line_idx12, line_idx23, line_idx30 |
|
394 |
+
| line_idx01, line_idx12, line_idx23, line_idx30 |
|
395 |
+
...
|
396 |
+
segments_list:
|
397 |
+
order: 0 > 1 > 2 > 3
|
398 |
+
| line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j |
|
399 |
+
| line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j |
|
400 |
+
...
|
401 |
+
'''
|
402 |
+
square_list.append(corner0[:2] + corner1[:2] + corner2[:2] + corner3[:2])
|
403 |
+
connect_list.append([corner0_line, corner1_line, corner2_line, corner3_line])
|
404 |
+
segments_list.append(corner0[2:] + corner1[2:] + corner2[2:] + corner3[2:])
|
405 |
+
|
406 |
+
def check_outside_inside(segments_info, connect_idx):
|
407 |
+
# return 'outside or inside', min distance, cover_param, peri_param
|
408 |
+
if connect_idx == segments_info[0]:
|
409 |
+
check_dist_mat = dist_inter_to_segment1
|
410 |
+
else:
|
411 |
+
check_dist_mat = dist_inter_to_segment2
|
412 |
+
|
413 |
+
i, j = segments_info
|
414 |
+
min_dist, max_dist = check_dist_mat[i, j, :]
|
415 |
+
connect_dist = dist_segments[connect_idx]
|
416 |
+
if max_dist > connect_dist:
|
417 |
+
return 'outside', min_dist, 0, 1
|
418 |
+
else:
|
419 |
+
return 'inside', min_dist, -1, -1
|
420 |
+
|
421 |
+
top_square = None
|
422 |
+
|
423 |
+
try:
|
424 |
+
map_size = input_shape[0] / 2
|
425 |
+
squares = np.array(square_list).reshape([-1, 4, 2])
|
426 |
+
score_array = []
|
427 |
+
connect_array = np.array(connect_list)
|
428 |
+
segments_array = np.array(segments_list).reshape([-1, 4, 2])
|
429 |
+
|
430 |
+
# get degree of corners:
|
431 |
+
squares_rollup = np.roll(squares, 1, axis=1)
|
432 |
+
squares_rolldown = np.roll(squares, -1, axis=1)
|
433 |
+
vec1 = squares_rollup - squares
|
434 |
+
normalized_vec1 = vec1 / (np.linalg.norm(vec1, axis=-1, keepdims=True) + 1e-10)
|
435 |
+
vec2 = squares_rolldown - squares
|
436 |
+
normalized_vec2 = vec2 / (np.linalg.norm(vec2, axis=-1, keepdims=True) + 1e-10)
|
437 |
+
inner_products = np.sum(normalized_vec1 * normalized_vec2, axis=-1) # [n_squares, 4]
|
438 |
+
squares_degree = np.arccos(inner_products) * 180 / np.pi # [n_squares, 4]
|
439 |
+
|
440 |
+
# get square score
|
441 |
+
overlap_scores = []
|
442 |
+
degree_scores = []
|
443 |
+
length_scores = []
|
444 |
+
|
445 |
+
for connects, segments, square, degree in zip(connect_array, segments_array, squares, squares_degree):
|
446 |
+
'''
|
447 |
+
0 -- 1
|
448 |
+
| |
|
449 |
+
3 -- 2
|
450 |
+
|
451 |
+
# segments: [4, 2]
|
452 |
+
# connects: [4]
|
453 |
+
'''
|
454 |
+
|
455 |
+
###################################### OVERLAP SCORES
|
456 |
+
cover = 0
|
457 |
+
perimeter = 0
|
458 |
+
# check 0 > 1 > 2 > 3
|
459 |
+
square_length = []
|
460 |
+
|
461 |
+
for start_idx in range(4):
|
462 |
+
end_idx = (start_idx + 1) % 4
|
463 |
+
|
464 |
+
connect_idx = connects[start_idx] # segment idx of segment01
|
465 |
+
start_segments = segments[start_idx]
|
466 |
+
end_segments = segments[end_idx]
|
467 |
+
|
468 |
+
start_point = square[start_idx]
|
469 |
+
end_point = square[end_idx]
|
470 |
+
|
471 |
+
# check whether outside or inside
|
472 |
+
start_position, start_min, start_cover_param, start_peri_param = check_outside_inside(start_segments,
|
473 |
+
connect_idx)
|
474 |
+
end_position, end_min, end_cover_param, end_peri_param = check_outside_inside(end_segments, connect_idx)
|
475 |
+
|
476 |
+
cover += dist_segments[connect_idx] + start_cover_param * start_min + end_cover_param * end_min
|
477 |
+
perimeter += dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min
|
478 |
+
|
479 |
+
square_length.append(
|
480 |
+
dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min)
|
481 |
+
|
482 |
+
overlap_scores.append(cover / perimeter)
|
483 |
+
######################################
|
484 |
+
###################################### DEGREE SCORES
|
485 |
+
'''
|
486 |
+
deg0 vs deg2
|
487 |
+
deg1 vs deg3
|
488 |
+
'''
|
489 |
+
deg0, deg1, deg2, deg3 = degree
|
490 |
+
deg_ratio1 = deg0 / deg2
|
491 |
+
if deg_ratio1 > 1.0:
|
492 |
+
deg_ratio1 = 1 / deg_ratio1
|
493 |
+
deg_ratio2 = deg1 / deg3
|
494 |
+
if deg_ratio2 > 1.0:
|
495 |
+
deg_ratio2 = 1 / deg_ratio2
|
496 |
+
degree_scores.append((deg_ratio1 + deg_ratio2) / 2)
|
497 |
+
######################################
|
498 |
+
###################################### LENGTH SCORES
|
499 |
+
'''
|
500 |
+
len0 vs len2
|
501 |
+
len1 vs len3
|
502 |
+
'''
|
503 |
+
len0, len1, len2, len3 = square_length
|
504 |
+
len_ratio1 = len0 / len2 if len2 > len0 else len2 / len0
|
505 |
+
len_ratio2 = len1 / len3 if len3 > len1 else len3 / len1
|
506 |
+
length_scores.append((len_ratio1 + len_ratio2) / 2)
|
507 |
+
|
508 |
+
######################################
|
509 |
+
|
510 |
+
overlap_scores = np.array(overlap_scores)
|
511 |
+
overlap_scores /= np.max(overlap_scores)
|
512 |
+
|
513 |
+
degree_scores = np.array(degree_scores)
|
514 |
+
# degree_scores /= np.max(degree_scores)
|
515 |
+
|
516 |
+
length_scores = np.array(length_scores)
|
517 |
+
|
518 |
+
###################################### AREA SCORES
|
519 |
+
area_scores = np.reshape(squares, [-1, 4, 2])
|
520 |
+
area_x = area_scores[:, :, 0]
|
521 |
+
area_y = area_scores[:, :, 1]
|
522 |
+
correction = area_x[:, -1] * area_y[:, 0] - area_y[:, -1] * area_x[:, 0]
|
523 |
+
area_scores = np.sum(area_x[:, :-1] * area_y[:, 1:], axis=-1) - np.sum(area_y[:, :-1] * area_x[:, 1:], axis=-1)
|
524 |
+
area_scores = 0.5 * np.abs(area_scores + correction)
|
525 |
+
area_scores /= (map_size * map_size) # np.max(area_scores)
|
526 |
+
######################################
|
527 |
+
|
528 |
+
###################################### CENTER SCORES
|
529 |
+
centers = np.array([[256 // 2, 256 // 2]], dtype='float32') # [1, 2]
|
530 |
+
# squares: [n, 4, 2]
|
531 |
+
square_centers = np.mean(squares, axis=1) # [n, 2]
|
532 |
+
center2center = np.sqrt(np.sum((centers - square_centers) ** 2))
|
533 |
+
center_scores = center2center / (map_size / np.sqrt(2.0))
|
534 |
+
|
535 |
+
'''
|
536 |
+
score_w = [overlap, degree, area, center, length]
|
537 |
+
'''
|
538 |
+
score_w = [0.0, 1.0, 10.0, 0.5, 1.0]
|
539 |
+
score_array = params['w_overlap'] * overlap_scores \
|
540 |
+
+ params['w_degree'] * degree_scores \
|
541 |
+
+ params['w_area'] * area_scores \
|
542 |
+
- params['w_center'] * center_scores \
|
543 |
+
+ params['w_length'] * length_scores
|
544 |
+
|
545 |
+
best_square = []
|
546 |
+
|
547 |
+
sorted_idx = np.argsort(score_array)[::-1]
|
548 |
+
score_array = score_array[sorted_idx]
|
549 |
+
squares = squares[sorted_idx]
|
550 |
+
|
551 |
+
except Exception as e:
|
552 |
+
pass
|
553 |
+
|
554 |
+
'''return list
|
555 |
+
merged_lines, squares, scores
|
556 |
+
'''
|
557 |
+
|
558 |
+
try:
|
559 |
+
new_segments[:, 0] = new_segments[:, 0] * 2 / input_shape[1] * original_shape[1]
|
560 |
+
new_segments[:, 1] = new_segments[:, 1] * 2 / input_shape[0] * original_shape[0]
|
561 |
+
new_segments[:, 2] = new_segments[:, 2] * 2 / input_shape[1] * original_shape[1]
|
562 |
+
new_segments[:, 3] = new_segments[:, 3] * 2 / input_shape[0] * original_shape[0]
|
563 |
+
except:
|
564 |
+
new_segments = []
|
565 |
+
|
566 |
+
try:
|
567 |
+
squares[:, :, 0] = squares[:, :, 0] * 2 / input_shape[1] * original_shape[1]
|
568 |
+
squares[:, :, 1] = squares[:, :, 1] * 2 / input_shape[0] * original_shape[0]
|
569 |
+
except:
|
570 |
+
squares = []
|
571 |
+
score_array = []
|
572 |
+
|
573 |
+
try:
|
574 |
+
inter_points = np.array(inter_points)
|
575 |
+
inter_points[:, 0] = inter_points[:, 0] * 2 / input_shape[1] * original_shape[1]
|
576 |
+
inter_points[:, 1] = inter_points[:, 1] * 2 / input_shape[0] * original_shape[0]
|
577 |
+
except:
|
578 |
+
inter_points = []
|
579 |
+
|
580 |
+
return new_segments, squares, score_array, inter_points
|
annotator/normalbae/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2022 Caroline Chan
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
annotator/normalbae/__init__.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Estimating and Exploiting the Aleatoric Uncertainty in Surface Normal Estimation
|
2 |
+
# https://github.com/baegwangbin/surface_normal_uncertainty
|
3 |
+
|
4 |
+
import os
|
5 |
+
import types
|
6 |
+
import torch
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
from einops import rearrange
|
10 |
+
from .models.NNET import NNET
|
11 |
+
from .utils import utils
|
12 |
+
from annotator.util import annotator_ckpts_path
|
13 |
+
import torchvision.transforms as transforms
|
14 |
+
|
15 |
+
|
16 |
+
class NormalBaeDetector:
|
17 |
+
def __init__(self):
|
18 |
+
remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/scannet.pt"
|
19 |
+
modelpath = os.path.join(annotator_ckpts_path, "scannet.pt")
|
20 |
+
if not os.path.exists(modelpath):
|
21 |
+
from basicsr.utils.download_util import load_file_from_url
|
22 |
+
load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
|
23 |
+
args = types.SimpleNamespace()
|
24 |
+
args.mode = 'client'
|
25 |
+
args.architecture = 'BN'
|
26 |
+
args.pretrained = 'scannet'
|
27 |
+
args.sampling_ratio = 0.4
|
28 |
+
args.importance_ratio = 0.7
|
29 |
+
model = NNET(args)
|
30 |
+
model = utils.load_checkpoint(modelpath, model)
|
31 |
+
model = model.cuda()
|
32 |
+
model.eval()
|
33 |
+
self.model = model
|
34 |
+
self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
35 |
+
|
36 |
+
def __call__(self, input_image):
|
37 |
+
assert input_image.ndim == 3
|
38 |
+
image_normal = input_image
|
39 |
+
with torch.no_grad():
|
40 |
+
image_normal = torch.from_numpy(image_normal).float().cuda()
|
41 |
+
image_normal = image_normal / 255.0
|
42 |
+
image_normal = rearrange(image_normal, 'h w c -> 1 c h w')
|
43 |
+
image_normal = self.norm(image_normal)
|
44 |
+
|
45 |
+
normal = self.model(image_normal)
|
46 |
+
normal = normal[0][-1][:, :3]
|
47 |
+
# d = torch.sum(normal ** 2.0, dim=1, keepdim=True) ** 0.5
|
48 |
+
# d = torch.maximum(d, torch.ones_like(d) * 1e-5)
|
49 |
+
# normal /= d
|
50 |
+
normal = ((normal + 1) * 0.5).clip(0, 1)
|
51 |
+
|
52 |
+
normal = rearrange(normal[0], 'c h w -> h w c').cpu().numpy()
|
53 |
+
normal_image = (normal * 255.0).clip(0, 255).astype(np.uint8)
|
54 |
+
|
55 |
+
return normal_image
|
annotator/normalbae/models/NNET.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from .submodules.encoder import Encoder
|
6 |
+
from .submodules.decoder import Decoder
|
7 |
+
|
8 |
+
|
9 |
+
class NNET(nn.Module):
|
10 |
+
def __init__(self, args):
|
11 |
+
super(NNET, self).__init__()
|
12 |
+
self.encoder = Encoder()
|
13 |
+
self.decoder = Decoder(args)
|
14 |
+
|
15 |
+
def get_1x_lr_params(self): # lr/10 learning rate
|
16 |
+
return self.encoder.parameters()
|
17 |
+
|
18 |
+
def get_10x_lr_params(self): # lr learning rate
|
19 |
+
return self.decoder.parameters()
|
20 |
+
|
21 |
+
def forward(self, img, **kwargs):
|
22 |
+
return self.decoder(self.encoder(img), **kwargs)
|
annotator/normalbae/models/baseline.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from .submodules.submodules import UpSampleBN, norm_normalize
|
6 |
+
|
7 |
+
|
8 |
+
# This is the baseline encoder-decoder we used in the ablation study
|
9 |
+
class NNET(nn.Module):
|
10 |
+
def __init__(self, args=None):
|
11 |
+
super(NNET, self).__init__()
|
12 |
+
self.encoder = Encoder()
|
13 |
+
self.decoder = Decoder(num_classes=4)
|
14 |
+
|
15 |
+
def forward(self, x, **kwargs):
|
16 |
+
out = self.decoder(self.encoder(x), **kwargs)
|
17 |
+
|
18 |
+
# Bilinearly upsample the output to match the input resolution
|
19 |
+
up_out = F.interpolate(out, size=[x.size(2), x.size(3)], mode='bilinear', align_corners=False)
|
20 |
+
|
21 |
+
# L2-normalize the first three channels / ensure positive value for concentration parameters (kappa)
|
22 |
+
up_out = norm_normalize(up_out)
|
23 |
+
return up_out
|
24 |
+
|
25 |
+
def get_1x_lr_params(self): # lr/10 learning rate
|
26 |
+
return self.encoder.parameters()
|
27 |
+
|
28 |
+
def get_10x_lr_params(self): # lr learning rate
|
29 |
+
modules = [self.decoder]
|
30 |
+
for m in modules:
|
31 |
+
yield from m.parameters()
|
32 |
+
|
33 |
+
|
34 |
+
# Encoder
|
35 |
+
class Encoder(nn.Module):
|
36 |
+
def __init__(self):
|
37 |
+
super(Encoder, self).__init__()
|
38 |
+
|
39 |
+
basemodel_name = 'tf_efficientnet_b5_ap'
|
40 |
+
basemodel = torch.hub.load('rwightman/gen-efficientnet-pytorch', basemodel_name, pretrained=True)
|
41 |
+
|
42 |
+
# Remove last layer
|
43 |
+
basemodel.global_pool = nn.Identity()
|
44 |
+
basemodel.classifier = nn.Identity()
|
45 |
+
|
46 |
+
self.original_model = basemodel
|
47 |
+
|
48 |
+
def forward(self, x):
|
49 |
+
features = [x]
|
50 |
+
for k, v in self.original_model._modules.items():
|
51 |
+
if (k == 'blocks'):
|
52 |
+
for ki, vi in v._modules.items():
|
53 |
+
features.append(vi(features[-1]))
|
54 |
+
else:
|
55 |
+
features.append(v(features[-1]))
|
56 |
+
return features
|
57 |
+
|
58 |
+
|
59 |
+
# Decoder (no pixel-wise MLP, no uncertainty-guided sampling)
|
60 |
+
class Decoder(nn.Module):
|
61 |
+
def __init__(self, num_classes=4):
|
62 |
+
super(Decoder, self).__init__()
|
63 |
+
self.conv2 = nn.Conv2d(2048, 2048, kernel_size=1, stride=1, padding=0)
|
64 |
+
self.up1 = UpSampleBN(skip_input=2048 + 176, output_features=1024)
|
65 |
+
self.up2 = UpSampleBN(skip_input=1024 + 64, output_features=512)
|
66 |
+
self.up3 = UpSampleBN(skip_input=512 + 40, output_features=256)
|
67 |
+
self.up4 = UpSampleBN(skip_input=256 + 24, output_features=128)
|
68 |
+
self.conv3 = nn.Conv2d(128, num_classes, kernel_size=3, stride=1, padding=1)
|
69 |
+
|
70 |
+
def forward(self, features):
|
71 |
+
x_block0, x_block1, x_block2, x_block3, x_block4 = features[4], features[5], features[6], features[8], features[11]
|
72 |
+
x_d0 = self.conv2(x_block4)
|
73 |
+
x_d1 = self.up1(x_d0, x_block3)
|
74 |
+
x_d2 = self.up2(x_d1, x_block2)
|
75 |
+
x_d3 = self.up3(x_d2, x_block1)
|
76 |
+
x_d4 = self.up4(x_d3, x_block0)
|
77 |
+
out = self.conv3(x_d4)
|
78 |
+
return out
|
79 |
+
|
80 |
+
|
81 |
+
if __name__ == '__main__':
|
82 |
+
model = Baseline()
|
83 |
+
x = torch.rand(2, 3, 480, 640)
|
84 |
+
out = model(x)
|
85 |
+
print(out.shape)
|
annotator/normalbae/models/submodules/decoder.py
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from .submodules import UpSampleBN, UpSampleGN, norm_normalize, sample_points
|
5 |
+
|
6 |
+
|
7 |
+
class Decoder(nn.Module):
|
8 |
+
def __init__(self, args):
|
9 |
+
super(Decoder, self).__init__()
|
10 |
+
|
11 |
+
# hyper-parameter for sampling
|
12 |
+
self.sampling_ratio = args.sampling_ratio
|
13 |
+
self.importance_ratio = args.importance_ratio
|
14 |
+
|
15 |
+
# feature-map
|
16 |
+
self.conv2 = nn.Conv2d(2048, 2048, kernel_size=1, stride=1, padding=0)
|
17 |
+
if args.architecture == 'BN':
|
18 |
+
self.up1 = UpSampleBN(skip_input=2048 + 176, output_features=1024)
|
19 |
+
self.up2 = UpSampleBN(skip_input=1024 + 64, output_features=512)
|
20 |
+
self.up3 = UpSampleBN(skip_input=512 + 40, output_features=256)
|
21 |
+
self.up4 = UpSampleBN(skip_input=256 + 24, output_features=128)
|
22 |
+
|
23 |
+
elif args.architecture == 'GN':
|
24 |
+
self.up1 = UpSampleGN(skip_input=2048 + 176, output_features=1024)
|
25 |
+
self.up2 = UpSampleGN(skip_input=1024 + 64, output_features=512)
|
26 |
+
self.up3 = UpSampleGN(skip_input=512 + 40, output_features=256)
|
27 |
+
self.up4 = UpSampleGN(skip_input=256 + 24, output_features=128)
|
28 |
+
|
29 |
+
else:
|
30 |
+
raise Exception('invalid architecture')
|
31 |
+
|
32 |
+
# produces 1/8 res output
|
33 |
+
self.out_conv_res8 = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
|
34 |
+
|
35 |
+
# produces 1/4 res output
|
36 |
+
self.out_conv_res4 = nn.Sequential(
|
37 |
+
nn.Conv1d(512 + 4, 128, kernel_size=1), nn.ReLU(),
|
38 |
+
nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
|
39 |
+
nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
|
40 |
+
nn.Conv1d(128, 4, kernel_size=1),
|
41 |
+
)
|
42 |
+
|
43 |
+
# produces 1/2 res output
|
44 |
+
self.out_conv_res2 = nn.Sequential(
|
45 |
+
nn.Conv1d(256 + 4, 128, kernel_size=1), nn.ReLU(),
|
46 |
+
nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
|
47 |
+
nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
|
48 |
+
nn.Conv1d(128, 4, kernel_size=1),
|
49 |
+
)
|
50 |
+
|
51 |
+
# produces 1/1 res output
|
52 |
+
self.out_conv_res1 = nn.Sequential(
|
53 |
+
nn.Conv1d(128 + 4, 128, kernel_size=1), nn.ReLU(),
|
54 |
+
nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
|
55 |
+
nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(),
|
56 |
+
nn.Conv1d(128, 4, kernel_size=1),
|
57 |
+
)
|
58 |
+
|
59 |
+
def forward(self, features, gt_norm_mask=None, mode='test'):
|
60 |
+
x_block0, x_block1, x_block2, x_block3, x_block4 = features[4], features[5], features[6], features[8], features[11]
|
61 |
+
|
62 |
+
# generate feature-map
|
63 |
+
|
64 |
+
x_d0 = self.conv2(x_block4) # x_d0 : [2, 2048, 15, 20] 1/32 res
|
65 |
+
x_d1 = self.up1(x_d0, x_block3) # x_d1 : [2, 1024, 30, 40] 1/16 res
|
66 |
+
x_d2 = self.up2(x_d1, x_block2) # x_d2 : [2, 512, 60, 80] 1/8 res
|
67 |
+
x_d3 = self.up3(x_d2, x_block1) # x_d3: [2, 256, 120, 160] 1/4 res
|
68 |
+
x_d4 = self.up4(x_d3, x_block0) # x_d4: [2, 128, 240, 320] 1/2 res
|
69 |
+
|
70 |
+
# 1/8 res output
|
71 |
+
out_res8 = self.out_conv_res8(x_d2) # out_res8: [2, 4, 60, 80] 1/8 res output
|
72 |
+
out_res8 = norm_normalize(out_res8) # out_res8: [2, 4, 60, 80] 1/8 res output
|
73 |
+
|
74 |
+
################################################################################################################
|
75 |
+
# out_res4
|
76 |
+
################################################################################################################
|
77 |
+
|
78 |
+
if mode == 'train':
|
79 |
+
# upsampling ... out_res8: [2, 4, 60, 80] -> out_res8_res4: [2, 4, 120, 160]
|
80 |
+
out_res8_res4 = F.interpolate(out_res8, scale_factor=2, mode='bilinear', align_corners=True)
|
81 |
+
B, _, H, W = out_res8_res4.shape
|
82 |
+
|
83 |
+
# samples: [B, 1, N, 2]
|
84 |
+
point_coords_res4, rows_int, cols_int = sample_points(out_res8_res4.detach(), gt_norm_mask,
|
85 |
+
sampling_ratio=self.sampling_ratio,
|
86 |
+
beta=self.importance_ratio)
|
87 |
+
|
88 |
+
# output (needed for evaluation / visualization)
|
89 |
+
out_res4 = out_res8_res4
|
90 |
+
|
91 |
+
# grid_sample feature-map
|
92 |
+
feat_res4 = F.grid_sample(x_d2, point_coords_res4, mode='bilinear', align_corners=True) # (B, 512, 1, N)
|
93 |
+
init_pred = F.grid_sample(out_res8, point_coords_res4, mode='bilinear', align_corners=True) # (B, 4, 1, N)
|
94 |
+
feat_res4 = torch.cat([feat_res4, init_pred], dim=1) # (B, 512+4, 1, N)
|
95 |
+
|
96 |
+
# prediction (needed to compute loss)
|
97 |
+
samples_pred_res4 = self.out_conv_res4(feat_res4[:, :, 0, :]) # (B, 4, N)
|
98 |
+
samples_pred_res4 = norm_normalize(samples_pred_res4) # (B, 4, N) - normalized
|
99 |
+
|
100 |
+
for i in range(B):
|
101 |
+
out_res4[i, :, rows_int[i, :], cols_int[i, :]] = samples_pred_res4[i, :, :]
|
102 |
+
|
103 |
+
else:
|
104 |
+
# grid_sample feature-map
|
105 |
+
feat_map = F.interpolate(x_d2, scale_factor=2, mode='bilinear', align_corners=True)
|
106 |
+
init_pred = F.interpolate(out_res8, scale_factor=2, mode='bilinear', align_corners=True)
|
107 |
+
feat_map = torch.cat([feat_map, init_pred], dim=1) # (B, 512+4, H, W)
|
108 |
+
B, _, H, W = feat_map.shape
|
109 |
+
|
110 |
+
# try all pixels
|
111 |
+
out_res4 = self.out_conv_res4(feat_map.view(B, 512 + 4, -1)) # (B, 4, N)
|
112 |
+
out_res4 = norm_normalize(out_res4) # (B, 4, N) - normalized
|
113 |
+
out_res4 = out_res4.view(B, 4, H, W)
|
114 |
+
samples_pred_res4 = point_coords_res4 = None
|
115 |
+
|
116 |
+
################################################################################################################
|
117 |
+
# out_res2
|
118 |
+
################################################################################################################
|
119 |
+
|
120 |
+
if mode == 'train':
|
121 |
+
|
122 |
+
# upsampling ... out_res4: [2, 4, 120, 160] -> out_res4_res2: [2, 4, 240, 320]
|
123 |
+
out_res4_res2 = F.interpolate(out_res4, scale_factor=2, mode='bilinear', align_corners=True)
|
124 |
+
B, _, H, W = out_res4_res2.shape
|
125 |
+
|
126 |
+
# samples: [B, 1, N, 2]
|
127 |
+
point_coords_res2, rows_int, cols_int = sample_points(out_res4_res2.detach(), gt_norm_mask,
|
128 |
+
sampling_ratio=self.sampling_ratio,
|
129 |
+
beta=self.importance_ratio)
|
130 |
+
|
131 |
+
# output (needed for evaluation / visualization)
|
132 |
+
out_res2 = out_res4_res2
|
133 |
+
|
134 |
+
# grid_sample feature-map
|
135 |
+
feat_res2 = F.grid_sample(x_d3, point_coords_res2, mode='bilinear', align_corners=True) # (B, 256, 1, N)
|
136 |
+
init_pred = F.grid_sample(out_res4, point_coords_res2, mode='bilinear', align_corners=True) # (B, 4, 1, N)
|
137 |
+
feat_res2 = torch.cat([feat_res2, init_pred], dim=1) # (B, 256+4, 1, N)
|
138 |
+
|
139 |
+
# prediction (needed to compute loss)
|
140 |
+
samples_pred_res2 = self.out_conv_res2(feat_res2[:, :, 0, :]) # (B, 4, N)
|
141 |
+
samples_pred_res2 = norm_normalize(samples_pred_res2) # (B, 4, N) - normalized
|
142 |
+
|
143 |
+
for i in range(B):
|
144 |
+
out_res2[i, :, rows_int[i, :], cols_int[i, :]] = samples_pred_res2[i, :, :]
|
145 |
+
|
146 |
+
else:
|
147 |
+
# grid_sample feature-map
|
148 |
+
feat_map = F.interpolate(x_d3, scale_factor=2, mode='bilinear', align_corners=True)
|
149 |
+
init_pred = F.interpolate(out_res4, scale_factor=2, mode='bilinear', align_corners=True)
|
150 |
+
feat_map = torch.cat([feat_map, init_pred], dim=1) # (B, 512+4, H, W)
|
151 |
+
B, _, H, W = feat_map.shape
|
152 |
+
|
153 |
+
out_res2 = self.out_conv_res2(feat_map.view(B, 256 + 4, -1)) # (B, 4, N)
|
154 |
+
out_res2 = norm_normalize(out_res2) # (B, 4, N) - normalized
|
155 |
+
out_res2 = out_res2.view(B, 4, H, W)
|
156 |
+
samples_pred_res2 = point_coords_res2 = None
|
157 |
+
|
158 |
+
################################################################################################################
|
159 |
+
# out_res1
|
160 |
+
################################################################################################################
|
161 |
+
|
162 |
+
if mode == 'train':
|
163 |
+
# upsampling ... out_res4: [2, 4, 120, 160] -> out_res4_res2: [2, 4, 240, 320]
|
164 |
+
out_res2_res1 = F.interpolate(out_res2, scale_factor=2, mode='bilinear', align_corners=True)
|
165 |
+
B, _, H, W = out_res2_res1.shape
|
166 |
+
|
167 |
+
# samples: [B, 1, N, 2]
|
168 |
+
point_coords_res1, rows_int, cols_int = sample_points(out_res2_res1.detach(), gt_norm_mask,
|
169 |
+
sampling_ratio=self.sampling_ratio,
|
170 |
+
beta=self.importance_ratio)
|
171 |
+
|
172 |
+
# output (needed for evaluation / visualization)
|
173 |
+
out_res1 = out_res2_res1
|
174 |
+
|
175 |
+
# grid_sample feature-map
|
176 |
+
feat_res1 = F.grid_sample(x_d4, point_coords_res1, mode='bilinear', align_corners=True) # (B, 128, 1, N)
|
177 |
+
init_pred = F.grid_sample(out_res2, point_coords_res1, mode='bilinear', align_corners=True) # (B, 4, 1, N)
|
178 |
+
feat_res1 = torch.cat([feat_res1, init_pred], dim=1) # (B, 128+4, 1, N)
|
179 |
+
|
180 |
+
# prediction (needed to compute loss)
|
181 |
+
samples_pred_res1 = self.out_conv_res1(feat_res1[:, :, 0, :]) # (B, 4, N)
|
182 |
+
samples_pred_res1 = norm_normalize(samples_pred_res1) # (B, 4, N) - normalized
|
183 |
+
|
184 |
+
for i in range(B):
|
185 |
+
out_res1[i, :, rows_int[i, :], cols_int[i, :]] = samples_pred_res1[i, :, :]
|
186 |
+
|
187 |
+
else:
|
188 |
+
# grid_sample feature-map
|
189 |
+
feat_map = F.interpolate(x_d4, scale_factor=2, mode='bilinear', align_corners=True)
|
190 |
+
init_pred = F.interpolate(out_res2, scale_factor=2, mode='bilinear', align_corners=True)
|
191 |
+
feat_map = torch.cat([feat_map, init_pred], dim=1) # (B, 512+4, H, W)
|
192 |
+
B, _, H, W = feat_map.shape
|
193 |
+
|
194 |
+
out_res1 = self.out_conv_res1(feat_map.view(B, 128 + 4, -1)) # (B, 4, N)
|
195 |
+
out_res1 = norm_normalize(out_res1) # (B, 4, N) - normalized
|
196 |
+
out_res1 = out_res1.view(B, 4, H, W)
|
197 |
+
samples_pred_res1 = point_coords_res1 = None
|
198 |
+
|
199 |
+
return [out_res8, out_res4, out_res2, out_res1], \
|
200 |
+
[out_res8, samples_pred_res4, samples_pred_res2, samples_pred_res1], \
|
201 |
+
[None, point_coords_res4, point_coords_res2, point_coords_res1]
|
202 |
+
|
annotator/normalbae/models/submodules/efficientnet_repo/BENCHMARK.md
ADDED
@@ -0,0 +1,555 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Model Performance Benchmarks
|
2 |
+
|
3 |
+
All benchmarks run as per:
|
4 |
+
|
5 |
+
```
|
6 |
+
python onnx_export.py --model mobilenetv3_100 ./mobilenetv3_100.onnx
|
7 |
+
python onnx_optimize.py ./mobilenetv3_100.onnx --output mobilenetv3_100-opt.onnx
|
8 |
+
python onnx_to_caffe.py ./mobilenetv3_100.onnx --c2-prefix mobilenetv3
|
9 |
+
python onnx_to_caffe.py ./mobilenetv3_100-opt.onnx --c2-prefix mobilenetv3-opt
|
10 |
+
python caffe2_benchmark.py --c2-init ./mobilenetv3.init.pb --c2-predict ./mobilenetv3.predict.pb
|
11 |
+
python caffe2_benchmark.py --c2-init ./mobilenetv3-opt.init.pb --c2-predict ./mobilenetv3-opt.predict.pb
|
12 |
+
```
|
13 |
+
|
14 |
+
## EfficientNet-B0
|
15 |
+
|
16 |
+
### Unoptimized
|
17 |
+
```
|
18 |
+
Main run finished. Milliseconds per iter: 49.2862. Iters per second: 20.2897
|
19 |
+
Time per operator type:
|
20 |
+
29.7378 ms. 60.5145%. Conv
|
21 |
+
12.1785 ms. 24.7824%. Sigmoid
|
22 |
+
3.62811 ms. 7.38297%. SpatialBN
|
23 |
+
2.98444 ms. 6.07314%. Mul
|
24 |
+
0.326902 ms. 0.665225%. AveragePool
|
25 |
+
0.197317 ms. 0.401528%. FC
|
26 |
+
0.0852877 ms. 0.173555%. Add
|
27 |
+
0.0032607 ms. 0.00663532%. Squeeze
|
28 |
+
49.1416 ms in Total
|
29 |
+
FLOP per operator type:
|
30 |
+
0.76907 GFLOP. 95.2696%. Conv
|
31 |
+
0.0269508 GFLOP. 3.33857%. SpatialBN
|
32 |
+
0.00846444 GFLOP. 1.04855%. Mul
|
33 |
+
0.002561 GFLOP. 0.317248%. FC
|
34 |
+
0.000210112 GFLOP. 0.0260279%. Add
|
35 |
+
0.807256 GFLOP in Total
|
36 |
+
Feature Memory Read per operator type:
|
37 |
+
58.5253 MB. 43.0891%. Mul
|
38 |
+
43.2015 MB. 31.807%. Conv
|
39 |
+
27.2869 MB. 20.0899%. SpatialBN
|
40 |
+
5.12912 MB. 3.77631%. FC
|
41 |
+
1.6809 MB. 1.23756%. Add
|
42 |
+
135.824 MB in Total
|
43 |
+
Feature Memory Written per operator type:
|
44 |
+
33.8578 MB. 38.1965%. Mul
|
45 |
+
26.9881 MB. 30.4465%. Conv
|
46 |
+
26.9508 MB. 30.4044%. SpatialBN
|
47 |
+
0.840448 MB. 0.948147%. Add
|
48 |
+
0.004 MB. 0.00451258%. FC
|
49 |
+
88.6412 MB in Total
|
50 |
+
Parameter Memory per operator type:
|
51 |
+
15.8248 MB. 74.9391%. Conv
|
52 |
+
5.124 MB. 24.265%. FC
|
53 |
+
0.168064 MB. 0.795877%. SpatialBN
|
54 |
+
0 MB. 0%. Add
|
55 |
+
0 MB. 0%. Mul
|
56 |
+
21.1168 MB in Total
|
57 |
+
```
|
58 |
+
### Optimized
|
59 |
+
```
|
60 |
+
Main run finished. Milliseconds per iter: 46.0838. Iters per second: 21.6996
|
61 |
+
Time per operator type:
|
62 |
+
29.776 ms. 65.002%. Conv
|
63 |
+
12.2803 ms. 26.8084%. Sigmoid
|
64 |
+
3.15073 ms. 6.87815%. Mul
|
65 |
+
0.328651 ms. 0.717456%. AveragePool
|
66 |
+
0.186237 ms. 0.406563%. FC
|
67 |
+
0.0832429 ms. 0.181722%. Add
|
68 |
+
0.0026184 ms. 0.00571606%. Squeeze
|
69 |
+
45.8078 ms in Total
|
70 |
+
FLOP per operator type:
|
71 |
+
0.76907 GFLOP. 98.5601%. Conv
|
72 |
+
0.00846444 GFLOP. 1.08476%. Mul
|
73 |
+
0.002561 GFLOP. 0.328205%. FC
|
74 |
+
0.000210112 GFLOP. 0.0269269%. Add
|
75 |
+
0.780305 GFLOP in Total
|
76 |
+
Feature Memory Read per operator type:
|
77 |
+
58.5253 MB. 53.8803%. Mul
|
78 |
+
43.2855 MB. 39.8501%. Conv
|
79 |
+
5.12912 MB. 4.72204%. FC
|
80 |
+
1.6809 MB. 1.54749%. Add
|
81 |
+
108.621 MB in Total
|
82 |
+
Feature Memory Written per operator type:
|
83 |
+
33.8578 MB. 54.8834%. Mul
|
84 |
+
26.9881 MB. 43.7477%. Conv
|
85 |
+
0.840448 MB. 1.36237%. Add
|
86 |
+
0.004 MB. 0.00648399%. FC
|
87 |
+
61.6904 MB in Total
|
88 |
+
Parameter Memory per operator type:
|
89 |
+
15.8248 MB. 75.5403%. Conv
|
90 |
+
5.124 MB. 24.4597%. FC
|
91 |
+
0 MB. 0%. Add
|
92 |
+
0 MB. 0%. Mul
|
93 |
+
20.9488 MB in Total
|
94 |
+
```
|
95 |
+
|
96 |
+
## EfficientNet-B1
|
97 |
+
### Optimized
|
98 |
+
```
|
99 |
+
Main run finished. Milliseconds per iter: 71.8102. Iters per second: 13.9256
|
100 |
+
Time per operator type:
|
101 |
+
45.7915 ms. 66.3206%. Conv
|
102 |
+
17.8718 ms. 25.8841%. Sigmoid
|
103 |
+
4.44132 ms. 6.43244%. Mul
|
104 |
+
0.51001 ms. 0.738658%. AveragePool
|
105 |
+
0.233283 ms. 0.337868%. Add
|
106 |
+
0.194986 ms. 0.282402%. FC
|
107 |
+
0.00268255 ms. 0.00388519%. Squeeze
|
108 |
+
69.0456 ms in Total
|
109 |
+
FLOP per operator type:
|
110 |
+
1.37105 GFLOP. 98.7673%. Conv
|
111 |
+
0.0138759 GFLOP. 0.99959%. Mul
|
112 |
+
0.002561 GFLOP. 0.184489%. FC
|
113 |
+
0.000674432 GFLOP. 0.0485847%. Add
|
114 |
+
1.38816 GFLOP in Total
|
115 |
+
Feature Memory Read per operator type:
|
116 |
+
94.624 MB. 54.0789%. Mul
|
117 |
+
69.8255 MB. 39.9062%. Conv
|
118 |
+
5.39546 MB. 3.08357%. Add
|
119 |
+
5.12912 MB. 2.93136%. FC
|
120 |
+
174.974 MB in Total
|
121 |
+
Feature Memory Written per operator type:
|
122 |
+
55.5035 MB. 54.555%. Mul
|
123 |
+
43.5333 MB. 42.7894%. Conv
|
124 |
+
2.69773 MB. 2.65163%. Add
|
125 |
+
0.004 MB. 0.00393165%. FC
|
126 |
+
101.739 MB in Total
|
127 |
+
Parameter Memory per operator type:
|
128 |
+
25.7479 MB. 83.4024%. Conv
|
129 |
+
5.124 MB. 16.5976%. FC
|
130 |
+
0 MB. 0%. Add
|
131 |
+
0 MB. 0%. Mul
|
132 |
+
30.8719 MB in Total
|
133 |
+
```
|
134 |
+
|
135 |
+
## EfficientNet-B2
|
136 |
+
### Optimized
|
137 |
+
```
|
138 |
+
Main run finished. Milliseconds per iter: 92.28. Iters per second: 10.8366
|
139 |
+
Time per operator type:
|
140 |
+
61.4627 ms. 67.5845%. Conv
|
141 |
+
22.7458 ms. 25.0113%. Sigmoid
|
142 |
+
5.59931 ms. 6.15701%. Mul
|
143 |
+
0.642567 ms. 0.706568%. AveragePool
|
144 |
+
0.272795 ms. 0.299965%. Add
|
145 |
+
0.216178 ms. 0.237709%. FC
|
146 |
+
0.00268895 ms. 0.00295677%. Squeeze
|
147 |
+
90.942 ms in Total
|
148 |
+
FLOP per operator type:
|
149 |
+
1.98431 GFLOP. 98.9343%. Conv
|
150 |
+
0.0177039 GFLOP. 0.882686%. Mul
|
151 |
+
0.002817 GFLOP. 0.140451%. FC
|
152 |
+
0.000853984 GFLOP. 0.0425782%. Add
|
153 |
+
2.00568 GFLOP in Total
|
154 |
+
Feature Memory Read per operator type:
|
155 |
+
120.609 MB. 54.9637%. Mul
|
156 |
+
86.3512 MB. 39.3519%. Conv
|
157 |
+
6.83187 MB. 3.11341%. Add
|
158 |
+
5.64163 MB. 2.571%. FC
|
159 |
+
219.433 MB in Total
|
160 |
+
Feature Memory Written per operator type:
|
161 |
+
70.8155 MB. 54.6573%. Mul
|
162 |
+
55.3273 MB. 42.7031%. Conv
|
163 |
+
3.41594 MB. 2.63651%. Add
|
164 |
+
0.004 MB. 0.00308731%. FC
|
165 |
+
129.563 MB in Total
|
166 |
+
Parameter Memory per operator type:
|
167 |
+
30.4721 MB. 84.3913%. Conv
|
168 |
+
5.636 MB. 15.6087%. FC
|
169 |
+
0 MB. 0%. Add
|
170 |
+
0 MB. 0%. Mul
|
171 |
+
36.1081 MB in Total
|
172 |
+
```
|
173 |
+
|
174 |
+
## MixNet-M
|
175 |
+
### Optimized
|
176 |
+
```
|
177 |
+
Main run finished. Milliseconds per iter: 63.1122. Iters per second: 15.8448
|
178 |
+
Time per operator type:
|
179 |
+
48.1139 ms. 75.2052%. Conv
|
180 |
+
7.1341 ms. 11.1511%. Sigmoid
|
181 |
+
2.63706 ms. 4.12189%. SpatialBN
|
182 |
+
1.73186 ms. 2.70701%. Mul
|
183 |
+
1.38707 ms. 2.16809%. Split
|
184 |
+
1.29322 ms. 2.02139%. Concat
|
185 |
+
1.00093 ms. 1.56452%. Relu
|
186 |
+
0.235309 ms. 0.367803%. Add
|
187 |
+
0.221579 ms. 0.346343%. FC
|
188 |
+
0.219315 ms. 0.342803%. AveragePool
|
189 |
+
0.00250145 ms. 0.00390993%. Squeeze
|
190 |
+
63.9768 ms in Total
|
191 |
+
FLOP per operator type:
|
192 |
+
0.675273 GFLOP. 95.5827%. Conv
|
193 |
+
0.0221072 GFLOP. 3.12921%. SpatialBN
|
194 |
+
0.00538445 GFLOP. 0.762152%. Mul
|
195 |
+
0.003073 GFLOP. 0.434973%. FC
|
196 |
+
0.000642488 GFLOP. 0.0909421%. Add
|
197 |
+
0 GFLOP. 0%. Concat
|
198 |
+
0 GFLOP. 0%. Relu
|
199 |
+
0.70648 GFLOP in Total
|
200 |
+
Feature Memory Read per operator type:
|
201 |
+
46.8424 MB. 30.502%. Conv
|
202 |
+
36.8626 MB. 24.0036%. Mul
|
203 |
+
22.3152 MB. 14.5309%. SpatialBN
|
204 |
+
22.1074 MB. 14.3955%. Concat
|
205 |
+
14.1496 MB. 9.21372%. Relu
|
206 |
+
6.15414 MB. 4.00735%. FC
|
207 |
+
5.1399 MB. 3.34692%. Add
|
208 |
+
153.571 MB in Total
|
209 |
+
Feature Memory Written per operator type:
|
210 |
+
32.7672 MB. 28.4331%. Conv
|
211 |
+
22.1072 MB. 19.1831%. Concat
|
212 |
+
22.1072 MB. 19.1831%. SpatialBN
|
213 |
+
21.5378 MB. 18.689%. Mul
|
214 |
+
14.1496 MB. 12.2781%. Relu
|
215 |
+
2.56995 MB. 2.23003%. Add
|
216 |
+
0.004 MB. 0.00347092%. FC
|
217 |
+
115.243 MB in Total
|
218 |
+
Parameter Memory per operator type:
|
219 |
+
13.7059 MB. 68.674%. Conv
|
220 |
+
6.148 MB. 30.8049%. FC
|
221 |
+
0.104 MB. 0.521097%. SpatialBN
|
222 |
+
0 MB. 0%. Add
|
223 |
+
0 MB. 0%. Concat
|
224 |
+
0 MB. 0%. Mul
|
225 |
+
0 MB. 0%. Relu
|
226 |
+
19.9579 MB in Total
|
227 |
+
```
|
228 |
+
|
229 |
+
## TF MobileNet-V3 Large 1.0
|
230 |
+
|
231 |
+
### Optimized
|
232 |
+
```
|
233 |
+
Main run finished. Milliseconds per iter: 22.0495. Iters per second: 45.3525
|
234 |
+
Time per operator type:
|
235 |
+
17.437 ms. 80.0087%. Conv
|
236 |
+
1.27662 ms. 5.8577%. Add
|
237 |
+
1.12759 ms. 5.17387%. Div
|
238 |
+
0.701155 ms. 3.21721%. Mul
|
239 |
+
0.562654 ms. 2.58171%. Relu
|
240 |
+
0.431144 ms. 1.97828%. Clip
|
241 |
+
0.156902 ms. 0.719936%. FC
|
242 |
+
0.0996858 ms. 0.457402%. AveragePool
|
243 |
+
0.00112455 ms. 0.00515993%. Flatten
|
244 |
+
21.7939 ms in Total
|
245 |
+
FLOP per operator type:
|
246 |
+
0.43062 GFLOP. 98.1484%. Conv
|
247 |
+
0.002561 GFLOP. 0.583713%. FC
|
248 |
+
0.00210867 GFLOP. 0.480616%. Mul
|
249 |
+
0.00193868 GFLOP. 0.441871%. Add
|
250 |
+
0.00151532 GFLOP. 0.345377%. Div
|
251 |
+
0 GFLOP. 0%. Relu
|
252 |
+
0.438743 GFLOP in Total
|
253 |
+
Feature Memory Read per operator type:
|
254 |
+
34.7967 MB. 43.9391%. Conv
|
255 |
+
14.496 MB. 18.3046%. Mul
|
256 |
+
9.44828 MB. 11.9307%. Add
|
257 |
+
9.26157 MB. 11.6949%. Relu
|
258 |
+
6.0614 MB. 7.65395%. Div
|
259 |
+
5.12912 MB. 6.47673%. FC
|
260 |
+
79.193 MB in Total
|
261 |
+
Feature Memory Written per operator type:
|
262 |
+
17.6247 MB. 35.8656%. Conv
|
263 |
+
9.26157 MB. 18.847%. Relu
|
264 |
+
8.43469 MB. 17.1643%. Mul
|
265 |
+
7.75472 MB. 15.7806%. Add
|
266 |
+
6.06128 MB. 12.3345%. Div
|
267 |
+
0.004 MB. 0.00813985%. FC
|
268 |
+
49.1409 MB in Total
|
269 |
+
Parameter Memory per operator type:
|
270 |
+
16.6851 MB. 76.5052%. Conv
|
271 |
+
5.124 MB. 23.4948%. FC
|
272 |
+
0 MB. 0%. Add
|
273 |
+
0 MB. 0%. Div
|
274 |
+
0 MB. 0%. Mul
|
275 |
+
0 MB. 0%. Relu
|
276 |
+
21.8091 MB in Total
|
277 |
+
```
|
278 |
+
|
279 |
+
## MobileNet-V3 (RW)
|
280 |
+
|
281 |
+
### Unoptimized
|
282 |
+
```
|
283 |
+
Main run finished. Milliseconds per iter: 24.8316. Iters per second: 40.2712
|
284 |
+
Time per operator type:
|
285 |
+
15.9266 ms. 69.2624%. Conv
|
286 |
+
2.36551 ms. 10.2873%. SpatialBN
|
287 |
+
1.39102 ms. 6.04936%. Add
|
288 |
+
1.30327 ms. 5.66773%. Div
|
289 |
+
0.737014 ms. 3.20517%. Mul
|
290 |
+
0.639697 ms. 2.78195%. Relu
|
291 |
+
0.375681 ms. 1.63378%. Clip
|
292 |
+
0.153126 ms. 0.665921%. FC
|
293 |
+
0.0993787 ms. 0.432184%. AveragePool
|
294 |
+
0.0032632 ms. 0.0141912%. Squeeze
|
295 |
+
22.9946 ms in Total
|
296 |
+
FLOP per operator type:
|
297 |
+
0.430616 GFLOP. 94.4041%. Conv
|
298 |
+
0.0175992 GFLOP. 3.85829%. SpatialBN
|
299 |
+
0.002561 GFLOP. 0.561449%. FC
|
300 |
+
0.00210961 GFLOP. 0.46249%. Mul
|
301 |
+
0.00173891 GFLOP. 0.381223%. Add
|
302 |
+
0.00151626 GFLOP. 0.33241%. Div
|
303 |
+
0 GFLOP. 0%. Relu
|
304 |
+
0.456141 GFLOP in Total
|
305 |
+
Feature Memory Read per operator type:
|
306 |
+
34.7354 MB. 36.4363%. Conv
|
307 |
+
17.7944 MB. 18.6658%. SpatialBN
|
308 |
+
14.5035 MB. 15.2137%. Mul
|
309 |
+
9.25778 MB. 9.71113%. Relu
|
310 |
+
7.84641 MB. 8.23064%. Add
|
311 |
+
6.06516 MB. 6.36216%. Div
|
312 |
+
5.12912 MB. 5.38029%. FC
|
313 |
+
95.3317 MB in Total
|
314 |
+
Feature Memory Written per operator type:
|
315 |
+
17.6246 MB. 26.7264%. Conv
|
316 |
+
17.5992 MB. 26.6878%. SpatialBN
|
317 |
+
9.25778 MB. 14.0387%. Relu
|
318 |
+
8.43843 MB. 12.7962%. Mul
|
319 |
+
6.95565 MB. 10.5477%. Add
|
320 |
+
6.06502 MB. 9.19713%. Div
|
321 |
+
0.004 MB. 0.00606568%. FC
|
322 |
+
65.9447 MB in Total
|
323 |
+
Parameter Memory per operator type:
|
324 |
+
16.6778 MB. 76.1564%. Conv
|
325 |
+
5.124 MB. 23.3979%. FC
|
326 |
+
0.0976 MB. 0.445674%. SpatialBN
|
327 |
+
0 MB. 0%. Add
|
328 |
+
0 MB. 0%. Div
|
329 |
+
0 MB. 0%. Mul
|
330 |
+
0 MB. 0%. Relu
|
331 |
+
21.8994 MB in Total
|
332 |
+
|
333 |
+
```
|
334 |
+
### Optimized
|
335 |
+
|
336 |
+
```
|
337 |
+
Main run finished. Milliseconds per iter: 22.0981. Iters per second: 45.2527
|
338 |
+
Time per operator type:
|
339 |
+
17.146 ms. 78.8965%. Conv
|
340 |
+
1.38453 ms. 6.37084%. Add
|
341 |
+
1.30991 ms. 6.02749%. Div
|
342 |
+
0.685417 ms. 3.15391%. Mul
|
343 |
+
0.532589 ms. 2.45068%. Relu
|
344 |
+
0.418263 ms. 1.92461%. Clip
|
345 |
+
0.15128 ms. 0.696106%. FC
|
346 |
+
0.102065 ms. 0.469648%. AveragePool
|
347 |
+
0.0022143 ms. 0.010189%. Squeeze
|
348 |
+
21.7323 ms in Total
|
349 |
+
FLOP per operator type:
|
350 |
+
0.430616 GFLOP. 98.1927%. Conv
|
351 |
+
0.002561 GFLOP. 0.583981%. FC
|
352 |
+
0.00210961 GFLOP. 0.481051%. Mul
|
353 |
+
0.00173891 GFLOP. 0.396522%. Add
|
354 |
+
0.00151626 GFLOP. 0.34575%. Div
|
355 |
+
0 GFLOP. 0%. Relu
|
356 |
+
0.438542 GFLOP in Total
|
357 |
+
Feature Memory Read per operator type:
|
358 |
+
34.7842 MB. 44.833%. Conv
|
359 |
+
14.5035 MB. 18.6934%. Mul
|
360 |
+
9.25778 MB. 11.9323%. Relu
|
361 |
+
7.84641 MB. 10.1132%. Add
|
362 |
+
6.06516 MB. 7.81733%. Div
|
363 |
+
5.12912 MB. 6.61087%. FC
|
364 |
+
77.5861 MB in Total
|
365 |
+
Feature Memory Written per operator type:
|
366 |
+
17.6246 MB. 36.4556%. Conv
|
367 |
+
9.25778 MB. 19.1492%. Relu
|
368 |
+
8.43843 MB. 17.4544%. Mul
|
369 |
+
6.95565 MB. 14.3874%. Add
|
370 |
+
6.06502 MB. 12.5452%. Div
|
371 |
+
0.004 MB. 0.00827378%. FC
|
372 |
+
48.3455 MB in Total
|
373 |
+
Parameter Memory per operator type:
|
374 |
+
16.6778 MB. 76.4973%. Conv
|
375 |
+
5.124 MB. 23.5027%. FC
|
376 |
+
0 MB. 0%. Add
|
377 |
+
0 MB. 0%. Div
|
378 |
+
0 MB. 0%. Mul
|
379 |
+
0 MB. 0%. Relu
|
380 |
+
21.8018 MB in Total
|
381 |
+
|
382 |
+
```
|
383 |
+
|
384 |
+
## MnasNet-A1
|
385 |
+
|
386 |
+
### Unoptimized
|
387 |
+
```
|
388 |
+
Main run finished. Milliseconds per iter: 30.0892. Iters per second: 33.2345
|
389 |
+
Time per operator type:
|
390 |
+
24.4656 ms. 79.0905%. Conv
|
391 |
+
4.14958 ms. 13.4144%. SpatialBN
|
392 |
+
1.60598 ms. 5.19169%. Relu
|
393 |
+
0.295219 ms. 0.95436%. Mul
|
394 |
+
0.187609 ms. 0.606486%. FC
|
395 |
+
0.120556 ms. 0.389724%. AveragePool
|
396 |
+
0.09036 ms. 0.292109%. Add
|
397 |
+
0.015727 ms. 0.050841%. Sigmoid
|
398 |
+
0.00306205 ms. 0.00989875%. Squeeze
|
399 |
+
30.9337 ms in Total
|
400 |
+
FLOP per operator type:
|
401 |
+
0.620598 GFLOP. 95.6434%. Conv
|
402 |
+
0.0248873 GFLOP. 3.8355%. SpatialBN
|
403 |
+
0.002561 GFLOP. 0.394688%. FC
|
404 |
+
0.000597408 GFLOP. 0.0920695%. Mul
|
405 |
+
0.000222656 GFLOP. 0.0343146%. Add
|
406 |
+
0 GFLOP. 0%. Relu
|
407 |
+
0.648867 GFLOP in Total
|
408 |
+
Feature Memory Read per operator type:
|
409 |
+
35.5457 MB. 38.4109%. Conv
|
410 |
+
25.1552 MB. 27.1829%. SpatialBN
|
411 |
+
22.5235 MB. 24.339%. Relu
|
412 |
+
5.12912 MB. 5.54256%. FC
|
413 |
+
2.40586 MB. 2.59978%. Mul
|
414 |
+
1.78125 MB. 1.92483%. Add
|
415 |
+
92.5406 MB in Total
|
416 |
+
Feature Memory Written per operator type:
|
417 |
+
24.9042 MB. 32.9424%. Conv
|
418 |
+
24.8873 MB. 32.92%. SpatialBN
|
419 |
+
22.5235 MB. 29.7932%. Relu
|
420 |
+
2.38963 MB. 3.16092%. Mul
|
421 |
+
0.890624 MB. 1.17809%. Add
|
422 |
+
0.004 MB. 0.00529106%. FC
|
423 |
+
75.5993 MB in Total
|
424 |
+
Parameter Memory per operator type:
|
425 |
+
10.2732 MB. 66.1459%. Conv
|
426 |
+
5.124 MB. 32.9917%. FC
|
427 |
+
0.133952 MB. 0.86247%. SpatialBN
|
428 |
+
0 MB. 0%. Add
|
429 |
+
0 MB. 0%. Mul
|
430 |
+
0 MB. 0%. Relu
|
431 |
+
15.5312 MB in Total
|
432 |
+
```
|
433 |
+
|
434 |
+
### Optimized
|
435 |
+
```
|
436 |
+
Main run finished. Milliseconds per iter: 24.2367. Iters per second: 41.2597
|
437 |
+
Time per operator type:
|
438 |
+
22.0547 ms. 91.1375%. Conv
|
439 |
+
1.49096 ms. 6.16116%. Relu
|
440 |
+
0.253417 ms. 1.0472%. Mul
|
441 |
+
0.18506 ms. 0.76473%. FC
|
442 |
+
0.112942 ms. 0.466717%. AveragePool
|
443 |
+
0.086769 ms. 0.358559%. Add
|
444 |
+
0.0127889 ms. 0.0528479%. Sigmoid
|
445 |
+
0.0027346 ms. 0.0113003%. Squeeze
|
446 |
+
24.1994 ms in Total
|
447 |
+
FLOP per operator type:
|
448 |
+
0.620598 GFLOP. 99.4581%. Conv
|
449 |
+
0.002561 GFLOP. 0.41043%. FC
|
450 |
+
0.000597408 GFLOP. 0.0957417%. Mul
|
451 |
+
0.000222656 GFLOP. 0.0356832%. Add
|
452 |
+
0 GFLOP. 0%. Relu
|
453 |
+
0.623979 GFLOP in Total
|
454 |
+
Feature Memory Read per operator type:
|
455 |
+
35.6127 MB. 52.7968%. Conv
|
456 |
+
22.5235 MB. 33.3917%. Relu
|
457 |
+
5.12912 MB. 7.60406%. FC
|
458 |
+
2.40586 MB. 3.56675%. Mul
|
459 |
+
1.78125 MB. 2.64075%. Add
|
460 |
+
67.4524 MB in Total
|
461 |
+
Feature Memory Written per operator type:
|
462 |
+
24.9042 MB. 49.1092%. Conv
|
463 |
+
22.5235 MB. 44.4145%. Relu
|
464 |
+
2.38963 MB. 4.71216%. Mul
|
465 |
+
0.890624 MB. 1.75624%. Add
|
466 |
+
0.004 MB. 0.00788768%. FC
|
467 |
+
50.712 MB in Total
|
468 |
+
Parameter Memory per operator type:
|
469 |
+
10.2732 MB. 66.7213%. Conv
|
470 |
+
5.124 MB. 33.2787%. FC
|
471 |
+
0 MB. 0%. Add
|
472 |
+
0 MB. 0%. Mul
|
473 |
+
0 MB. 0%. Relu
|
474 |
+
15.3972 MB in Total
|
475 |
+
```
|
476 |
+
## MnasNet-B1
|
477 |
+
|
478 |
+
### Unoptimized
|
479 |
+
```
|
480 |
+
Main run finished. Milliseconds per iter: 28.3109. Iters per second: 35.322
|
481 |
+
Time per operator type:
|
482 |
+
29.1121 ms. 83.3081%. Conv
|
483 |
+
4.14959 ms. 11.8746%. SpatialBN
|
484 |
+
1.35823 ms. 3.88675%. Relu
|
485 |
+
0.186188 ms. 0.532802%. FC
|
486 |
+
0.116244 ms. 0.332647%. Add
|
487 |
+
0.018641 ms. 0.0533437%. AveragePool
|
488 |
+
0.0040904 ms. 0.0117052%. Squeeze
|
489 |
+
34.9451 ms in Total
|
490 |
+
FLOP per operator type:
|
491 |
+
0.626272 GFLOP. 96.2088%. Conv
|
492 |
+
0.0218266 GFLOP. 3.35303%. SpatialBN
|
493 |
+
0.002561 GFLOP. 0.393424%. FC
|
494 |
+
0.000291648 GFLOP. 0.0448034%. Add
|
495 |
+
0 GFLOP. 0%. Relu
|
496 |
+
0.650951 GFLOP in Total
|
497 |
+
Feature Memory Read per operator type:
|
498 |
+
34.4354 MB. 41.3788%. Conv
|
499 |
+
22.1299 MB. 26.5921%. SpatialBN
|
500 |
+
19.1923 MB. 23.0622%. Relu
|
501 |
+
5.12912 MB. 6.16333%. FC
|
502 |
+
2.33318 MB. 2.80364%. Add
|
503 |
+
83.2199 MB in Total
|
504 |
+
Feature Memory Written per operator type:
|
505 |
+
21.8266 MB. 34.0955%. Conv
|
506 |
+
21.8266 MB. 34.0955%. SpatialBN
|
507 |
+
19.1923 MB. 29.9805%. Relu
|
508 |
+
1.16659 MB. 1.82234%. Add
|
509 |
+
0.004 MB. 0.00624844%. FC
|
510 |
+
64.016 MB in Total
|
511 |
+
Parameter Memory per operator type:
|
512 |
+
12.2576 MB. 69.9104%. Conv
|
513 |
+
5.124 MB. 29.2245%. FC
|
514 |
+
0.15168 MB. 0.865099%. SpatialBN
|
515 |
+
0 MB. 0%. Add
|
516 |
+
0 MB. 0%. Relu
|
517 |
+
17.5332 MB in Total
|
518 |
+
```
|
519 |
+
|
520 |
+
### Optimized
|
521 |
+
```
|
522 |
+
Main run finished. Milliseconds per iter: 26.6364. Iters per second: 37.5426
|
523 |
+
Time per operator type:
|
524 |
+
24.9888 ms. 94.0962%. Conv
|
525 |
+
1.26147 ms. 4.75011%. Relu
|
526 |
+
0.176234 ms. 0.663619%. FC
|
527 |
+
0.113309 ms. 0.426672%. Add
|
528 |
+
0.0138708 ms. 0.0522311%. AveragePool
|
529 |
+
0.00295685 ms. 0.0111341%. Squeeze
|
530 |
+
26.5566 ms in Total
|
531 |
+
FLOP per operator type:
|
532 |
+
0.626272 GFLOP. 99.5466%. Conv
|
533 |
+
0.002561 GFLOP. 0.407074%. FC
|
534 |
+
0.000291648 GFLOP. 0.0463578%. Add
|
535 |
+
0 GFLOP. 0%. Relu
|
536 |
+
0.629124 GFLOP in Total
|
537 |
+
Feature Memory Read per operator type:
|
538 |
+
34.5112 MB. 56.4224%. Conv
|
539 |
+
19.1923 MB. 31.3775%. Relu
|
540 |
+
5.12912 MB. 8.3856%. FC
|
541 |
+
2.33318 MB. 3.81452%. Add
|
542 |
+
61.1658 MB in Total
|
543 |
+
Feature Memory Written per operator type:
|
544 |
+
21.8266 MB. 51.7346%. Conv
|
545 |
+
19.1923 MB. 45.4908%. Relu
|
546 |
+
1.16659 MB. 2.76513%. Add
|
547 |
+
0.004 MB. 0.00948104%. FC
|
548 |
+
42.1895 MB in Total
|
549 |
+
Parameter Memory per operator type:
|
550 |
+
12.2576 MB. 70.5205%. Conv
|
551 |
+
5.124 MB. 29.4795%. FC
|
552 |
+
0 MB. 0%. Add
|
553 |
+
0 MB. 0%. Relu
|
554 |
+
17.3816 MB in Total
|
555 |
+
```
|
annotator/normalbae/models/submodules/efficientnet_repo/LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "{}"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright 2020 Ross Wightman
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
annotator/normalbae/models/submodules/efficientnet_repo/README.md
ADDED
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# (Generic) EfficientNets for PyTorch
|
2 |
+
|
3 |
+
A 'generic' implementation of EfficientNet, MixNet, MobileNetV3, etc. that covers most of the compute/parameter efficient architectures derived from the MobileNet V1/V2 block sequence, including those found via automated neural architecture search.
|
4 |
+
|
5 |
+
All models are implemented by GenEfficientNet or MobileNetV3 classes, with string based architecture definitions to configure the block layouts (idea from [here](https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py))
|
6 |
+
|
7 |
+
## What's New
|
8 |
+
|
9 |
+
### Aug 19, 2020
|
10 |
+
* Add updated PyTorch trained EfficientNet-B3 weights trained by myself with `timm` (82.1 top-1)
|
11 |
+
* Add PyTorch trained EfficientNet-Lite0 contributed by [@hal-314](https://github.com/hal-314) (75.5 top-1)
|
12 |
+
* Update ONNX and Caffe2 export / utility scripts to work with latest PyTorch / ONNX
|
13 |
+
* ONNX runtime based validation script added
|
14 |
+
* activations (mostly) brought in sync with `timm` equivalents
|
15 |
+
|
16 |
+
|
17 |
+
### April 5, 2020
|
18 |
+
* Add some newly trained MobileNet-V2 models trained with latest h-params, rand augment. They compare quite favourably to EfficientNet-Lite
|
19 |
+
* 3.5M param MobileNet-V2 100 @ 73%
|
20 |
+
* 4.5M param MobileNet-V2 110d @ 75%
|
21 |
+
* 6.1M param MobileNet-V2 140 @ 76.5%
|
22 |
+
* 5.8M param MobileNet-V2 120d @ 77.3%
|
23 |
+
|
24 |
+
### March 23, 2020
|
25 |
+
* Add EfficientNet-Lite models w/ weights ported from [Tensorflow TPU](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite)
|
26 |
+
* Add PyTorch trained MobileNet-V3 Large weights with 75.77% top-1
|
27 |
+
* IMPORTANT CHANGE (if training from scratch) - weight init changed to better match Tensorflow impl, set `fix_group_fanout=False` in `initialize_weight_goog` for old behavior
|
28 |
+
|
29 |
+
### Feb 12, 2020
|
30 |
+
* Add EfficientNet-L2 and B0-B7 NoisyStudent weights ported from [Tensorflow TPU](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet)
|
31 |
+
* Port new EfficientNet-B8 (RandAugment) weights from TF TPU, these are different than the B8 AdvProp, different input normalization.
|
32 |
+
* Add RandAugment PyTorch trained EfficientNet-ES (EdgeTPU-Small) weights with 78.1 top-1. Trained by [Andrew Lavin](https://github.com/andravin)
|
33 |
+
|
34 |
+
### Jan 22, 2020
|
35 |
+
* Update weights for EfficientNet B0, B2, B3 and MixNet-XL with latest RandAugment trained weights. Trained with (https://github.com/rwightman/pytorch-image-models)
|
36 |
+
* Fix torchscript compatibility for PyTorch 1.4, add torchscript support for MixedConv2d using ModuleDict
|
37 |
+
* Test models, torchscript, onnx export with PyTorch 1.4 -- no issues
|
38 |
+
|
39 |
+
### Nov 22, 2019
|
40 |
+
* New top-1 high! Ported official TF EfficientNet AdvProp (https://arxiv.org/abs/1911.09665) weights and B8 model spec. Created a new set of `ap` models since they use a different
|
41 |
+
preprocessing (Inception mean/std) from the original EfficientNet base/AA/RA weights.
|
42 |
+
|
43 |
+
### Nov 15, 2019
|
44 |
+
* Ported official TF MobileNet-V3 float32 large/small/minimalistic weights
|
45 |
+
* Modifications to MobileNet-V3 model and components to support some additional config needed for differences between TF MobileNet-V3 and mine
|
46 |
+
|
47 |
+
### Oct 30, 2019
|
48 |
+
* Many of the models will now work with torch.jit.script, MixNet being the biggest exception
|
49 |
+
* Improved interface for enabling torchscript or ONNX export compatible modes (via config)
|
50 |
+
* Add JIT optimized mem-efficient Swish/Mish autograd.fn in addition to memory-efficient autgrad.fn
|
51 |
+
* Activation factory to select best version of activation by name or override one globally
|
52 |
+
* Add pretrained checkpoint load helper that handles input conv and classifier changes
|
53 |
+
|
54 |
+
### Oct 27, 2019
|
55 |
+
* Add CondConv EfficientNet variants ported from https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/condconv
|
56 |
+
* Add RandAug weights for TF EfficientNet B5 and B7 from https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet
|
57 |
+
* Bring over MixNet-XL model and depth scaling algo from my pytorch-image-models code base
|
58 |
+
* Switch activations and global pooling to modules
|
59 |
+
* Add memory-efficient Swish/Mish impl
|
60 |
+
* Add as_sequential() method to all models and allow as an argument in entrypoint fns
|
61 |
+
* Move MobileNetV3 into own file since it has a different head
|
62 |
+
* Remove ChamNet, MobileNet V2/V1 since they will likely never be used here
|
63 |
+
|
64 |
+
## Models
|
65 |
+
|
66 |
+
Implemented models include:
|
67 |
+
* EfficientNet NoisyStudent (B0-B7, L2) (https://arxiv.org/abs/1911.04252)
|
68 |
+
* EfficientNet AdvProp (B0-B8) (https://arxiv.org/abs/1911.09665)
|
69 |
+
* EfficientNet (B0-B8) (https://arxiv.org/abs/1905.11946)
|
70 |
+
* EfficientNet-EdgeTPU (S, M, L) (https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html)
|
71 |
+
* EfficientNet-CondConv (https://arxiv.org/abs/1904.04971)
|
72 |
+
* EfficientNet-Lite (https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite)
|
73 |
+
* MixNet (https://arxiv.org/abs/1907.09595)
|
74 |
+
* MNASNet B1, A1 (Squeeze-Excite), and Small (https://arxiv.org/abs/1807.11626)
|
75 |
+
* MobileNet-V3 (https://arxiv.org/abs/1905.02244)
|
76 |
+
* FBNet-C (https://arxiv.org/abs/1812.03443)
|
77 |
+
* Single-Path NAS (https://arxiv.org/abs/1904.02877)
|
78 |
+
|
79 |
+
I originally implemented and trained some these models with code [here](https://github.com/rwightman/pytorch-image-models), this repository contains just the GenEfficientNet models, validation, and associated ONNX/Caffe2 export code.
|
80 |
+
|
81 |
+
## Pretrained
|
82 |
+
|
83 |
+
I've managed to train several of the models to accuracies close to or above the originating papers and official impl. My training code is here: https://github.com/rwightman/pytorch-image-models
|
84 |
+
|
85 |
+
|
86 |
+
|Model | Prec@1 (Err) | Prec@5 (Err) | Param#(M) | MAdds(M) | Image Scaling | Resolution | Crop |
|
87 |
+
|---|---|---|---|---|---|---|---|
|
88 |
+
| efficientnet_b3 | 82.240 (17.760) | 96.116 (3.884) | 12.23 | TBD | bicubic | 320 | 1.0 |
|
89 |
+
| efficientnet_b3 | 82.076 (17.924) | 96.020 (3.980) | 12.23 | TBD | bicubic | 300 | 0.904 |
|
90 |
+
| mixnet_xl | 81.074 (18.926) | 95.282 (4.718) | 11.90 | TBD | bicubic | 256 | 1.0 |
|
91 |
+
| efficientnet_b2 | 80.612 (19.388) | 95.318 (4.682) | 9.1 | TBD | bicubic | 288 | 1.0 |
|
92 |
+
| mixnet_xl | 80.476 (19.524) | 94.936 (5.064) | 11.90 | TBD | bicubic | 224 | 0.875 |
|
93 |
+
| efficientnet_b2 | 80.288 (19.712) | 95.166 (4.834) | 9.1 | 1003 | bicubic | 260 | 0.890 |
|
94 |
+
| mixnet_l | 78.976 (21.024 | 94.184 (5.816) | 7.33 | TBD | bicubic | 224 | 0.875 |
|
95 |
+
| efficientnet_b1 | 78.692 (21.308) | 94.086 (5.914) | 7.8 | 694 | bicubic | 240 | 0.882 |
|
96 |
+
| efficientnet_es | 78.066 (21.934) | 93.926 (6.074) | 5.44 | TBD | bicubic | 224 | 0.875 |
|
97 |
+
| efficientnet_b0 | 77.698 (22.302) | 93.532 (6.468) | 5.3 | 390 | bicubic | 224 | 0.875 |
|
98 |
+
| mobilenetv2_120d | 77.294 (22.706 | 93.502 (6.498) | 5.8 | TBD | bicubic | 224 | 0.875 |
|
99 |
+
| mixnet_m | 77.256 (22.744) | 93.418 (6.582) | 5.01 | 353 | bicubic | 224 | 0.875 |
|
100 |
+
| mobilenetv2_140 | 76.524 (23.476) | 92.990 (7.010) | 6.1 | TBD | bicubic | 224 | 0.875 |
|
101 |
+
| mixnet_s | 75.988 (24.012) | 92.794 (7.206) | 4.13 | TBD | bicubic | 224 | 0.875 |
|
102 |
+
| mobilenetv3_large_100 | 75.766 (24.234) | 92.542 (7.458) | 5.5 | TBD | bicubic | 224 | 0.875 |
|
103 |
+
| mobilenetv3_rw | 75.634 (24.366) | 92.708 (7.292) | 5.5 | 219 | bicubic | 224 | 0.875 |
|
104 |
+
| efficientnet_lite0 | 75.472 (24.528) | 92.520 (7.480) | 4.65 | TBD | bicubic | 224 | 0.875 |
|
105 |
+
| mnasnet_a1 | 75.448 (24.552) | 92.604 (7.396) | 3.9 | 312 | bicubic | 224 | 0.875 |
|
106 |
+
| fbnetc_100 | 75.124 (24.876) | 92.386 (7.614) | 5.6 | 385 | bilinear | 224 | 0.875 |
|
107 |
+
| mobilenetv2_110d | 75.052 (24.948) | 92.180 (7.820) | 4.5 | TBD | bicubic | 224 | 0.875 |
|
108 |
+
| mnasnet_b1 | 74.658 (25.342) | 92.114 (7.886) | 4.4 | 315 | bicubic | 224 | 0.875 |
|
109 |
+
| spnasnet_100 | 74.084 (25.916) | 91.818 (8.182) | 4.4 | TBD | bilinear | 224 | 0.875 |
|
110 |
+
| mobilenetv2_100 | 72.978 (27.022) | 91.016 (8.984) | 3.5 | TBD | bicubic | 224 | 0.875 |
|
111 |
+
|
112 |
+
|
113 |
+
More pretrained models to come...
|
114 |
+
|
115 |
+
|
116 |
+
## Ported Weights
|
117 |
+
|
118 |
+
The weights ported from Tensorflow checkpoints for the EfficientNet models do pretty much match accuracy in Tensorflow once a SAME convolution padding equivalent is added, and the same crop factors, image scaling, etc (see table) are used via cmd line args.
|
119 |
+
|
120 |
+
**IMPORTANT:**
|
121 |
+
* Tensorflow ported weights for EfficientNet AdvProp (AP), EfficientNet EdgeTPU, EfficientNet-CondConv, EfficientNet-Lite, and MobileNet-V3 models use Inception style (0.5, 0.5, 0.5) for mean and std.
|
122 |
+
* Enabling the Tensorflow preprocessing pipeline with `--tf-preprocessing` at validation time will improve scores by 0.1-0.5%, very close to original TF impl.
|
123 |
+
|
124 |
+
To run validation for tf_efficientnet_b5:
|
125 |
+
`python validate.py /path/to/imagenet/validation/ --model tf_efficientnet_b5 -b 64 --img-size 456 --crop-pct 0.934 --interpolation bicubic`
|
126 |
+
|
127 |
+
To run validation w/ TF preprocessing for tf_efficientnet_b5:
|
128 |
+
`python validate.py /path/to/imagenet/validation/ --model tf_efficientnet_b5 -b 64 --img-size 456 --tf-preprocessing`
|
129 |
+
|
130 |
+
To run validation for a model with Inception preprocessing, ie EfficientNet-B8 AdvProp:
|
131 |
+
`python validate.py /path/to/imagenet/validation/ --model tf_efficientnet_b8_ap -b 48 --num-gpu 2 --img-size 672 --crop-pct 0.954 --mean 0.5 --std 0.5`
|
132 |
+
|
133 |
+
|Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling | Image Size | Crop |
|
134 |
+
|---|---|---|---|---|---|---|
|
135 |
+
| tf_efficientnet_l2_ns *tfp | 88.352 (11.648) | 98.652 (1.348) | 480 | bicubic | 800 | N/A |
|
136 |
+
| tf_efficientnet_l2_ns | TBD | TBD | 480 | bicubic | 800 | 0.961 |
|
137 |
+
| tf_efficientnet_l2_ns_475 | 88.234 (11.766) | 98.546 (1.454) | 480 | bicubic | 475 | 0.936 |
|
138 |
+
| tf_efficientnet_l2_ns_475 *tfp | 88.172 (11.828) | 98.566 (1.434) | 480 | bicubic | 475 | N/A |
|
139 |
+
| tf_efficientnet_b7_ns *tfp | 86.844 (13.156) | 98.084 (1.916) | 66.35 | bicubic | 600 | N/A |
|
140 |
+
| tf_efficientnet_b7_ns | 86.840 (13.160) | 98.094 (1.906) | 66.35 | bicubic | 600 | N/A |
|
141 |
+
| tf_efficientnet_b6_ns | 86.452 (13.548) | 97.882 (2.118) | 43.04 | bicubic | 528 | N/A |
|
142 |
+
| tf_efficientnet_b6_ns *tfp | 86.444 (13.556) | 97.880 (2.120) | 43.04 | bicubic | 528 | N/A |
|
143 |
+
| tf_efficientnet_b5_ns *tfp | 86.064 (13.936) | 97.746 (2.254) | 30.39 | bicubic | 456 | N/A |
|
144 |
+
| tf_efficientnet_b5_ns | 86.088 (13.912) | 97.752 (2.248) | 30.39 | bicubic | 456 | N/A |
|
145 |
+
| tf_efficientnet_b8_ap *tfp | 85.436 (14.564) | 97.272 (2.728) | 87.4 | bicubic | 672 | N/A |
|
146 |
+
| tf_efficientnet_b8 *tfp | 85.384 (14.616) | 97.394 (2.606) | 87.4 | bicubic | 672 | N/A |
|
147 |
+
| tf_efficientnet_b8 | 85.370 (14.630) | 97.390 (2.610) | 87.4 | bicubic | 672 | 0.954 |
|
148 |
+
| tf_efficientnet_b8_ap | 85.368 (14.632) | 97.294 (2.706) | 87.4 | bicubic | 672 | 0.954 |
|
149 |
+
| tf_efficientnet_b4_ns *tfp | 85.298 (14.702) | 97.504 (2.496) | 19.34 | bicubic | 380 | N/A |
|
150 |
+
| tf_efficientnet_b4_ns | 85.162 (14.838) | 97.470 (2.530) | 19.34 | bicubic | 380 | 0.922 |
|
151 |
+
| tf_efficientnet_b7_ap *tfp | 85.154 (14.846) | 97.244 (2.756) | 66.35 | bicubic | 600 | N/A |
|
152 |
+
| tf_efficientnet_b7_ap | 85.118 (14.882) | 97.252 (2.748) | 66.35 | bicubic | 600 | 0.949 |
|
153 |
+
| tf_efficientnet_b7 *tfp | 84.940 (15.060) | 97.214 (2.786) | 66.35 | bicubic | 600 | N/A |
|
154 |
+
| tf_efficientnet_b7 | 84.932 (15.068) | 97.208 (2.792) | 66.35 | bicubic | 600 | 0.949 |
|
155 |
+
| tf_efficientnet_b6_ap | 84.786 (15.214) | 97.138 (2.862) | 43.04 | bicubic | 528 | 0.942 |
|
156 |
+
| tf_efficientnet_b6_ap *tfp | 84.760 (15.240) | 97.124 (2.876) | 43.04 | bicubic | 528 | N/A |
|
157 |
+
| tf_efficientnet_b5_ap *tfp | 84.276 (15.724) | 96.932 (3.068) | 30.39 | bicubic | 456 | N/A |
|
158 |
+
| tf_efficientnet_b5_ap | 84.254 (15.746) | 96.976 (3.024) | 30.39 | bicubic | 456 | 0.934 |
|
159 |
+
| tf_efficientnet_b6 *tfp | 84.140 (15.860) | 96.852 (3.148) | 43.04 | bicubic | 528 | N/A |
|
160 |
+
| tf_efficientnet_b6 | 84.110 (15.890) | 96.886 (3.114) | 43.04 | bicubic | 528 | 0.942 |
|
161 |
+
| tf_efficientnet_b3_ns *tfp | 84.054 (15.946) | 96.918 (3.082) | 12.23 | bicubic | 300 | N/A |
|
162 |
+
| tf_efficientnet_b3_ns | 84.048 (15.952) | 96.910 (3.090) | 12.23 | bicubic | 300 | .904 |
|
163 |
+
| tf_efficientnet_b5 *tfp | 83.822 (16.178) | 96.756 (3.244) | 30.39 | bicubic | 456 | N/A |
|
164 |
+
| tf_efficientnet_b5 | 83.812 (16.188) | 96.748 (3.252) | 30.39 | bicubic | 456 | 0.934 |
|
165 |
+
| tf_efficientnet_b4_ap *tfp | 83.278 (16.722) | 96.376 (3.624) | 19.34 | bicubic | 380 | N/A |
|
166 |
+
| tf_efficientnet_b4_ap | 83.248 (16.752) | 96.388 (3.612) | 19.34 | bicubic | 380 | 0.922 |
|
167 |
+
| tf_efficientnet_b4 | 83.022 (16.978) | 96.300 (3.700) | 19.34 | bicubic | 380 | 0.922 |
|
168 |
+
| tf_efficientnet_b4 *tfp | 82.948 (17.052) | 96.308 (3.692) | 19.34 | bicubic | 380 | N/A |
|
169 |
+
| tf_efficientnet_b2_ns *tfp | 82.436 (17.564) | 96.268 (3.732) | 9.11 | bicubic | 260 | N/A |
|
170 |
+
| tf_efficientnet_b2_ns | 82.380 (17.620) | 96.248 (3.752) | 9.11 | bicubic | 260 | 0.89 |
|
171 |
+
| tf_efficientnet_b3_ap *tfp | 81.882 (18.118) | 95.662 (4.338) | 12.23 | bicubic | 300 | N/A |
|
172 |
+
| tf_efficientnet_b3_ap | 81.828 (18.172) | 95.624 (4.376) | 12.23 | bicubic | 300 | 0.904 |
|
173 |
+
| tf_efficientnet_b3 | 81.636 (18.364) | 95.718 (4.282) | 12.23 | bicubic | 300 | 0.904 |
|
174 |
+
| tf_efficientnet_b3 *tfp | 81.576 (18.424) | 95.662 (4.338) | 12.23 | bicubic | 300 | N/A |
|
175 |
+
| tf_efficientnet_lite4 | 81.528 (18.472) | 95.668 (4.332) | 13.00 | bilinear | 380 | 0.92 |
|
176 |
+
| tf_efficientnet_b1_ns *tfp | 81.514 (18.486) | 95.776 (4.224) | 7.79 | bicubic | 240 | N/A |
|
177 |
+
| tf_efficientnet_lite4 *tfp | 81.502 (18.498) | 95.676 (4.324) | 13.00 | bilinear | 380 | N/A |
|
178 |
+
| tf_efficientnet_b1_ns | 81.388 (18.612) | 95.738 (4.262) | 7.79 | bicubic | 240 | 0.88 |
|
179 |
+
| tf_efficientnet_el | 80.534 (19.466) | 95.190 (4.810) | 10.59 | bicubic | 300 | 0.904 |
|
180 |
+
| tf_efficientnet_el *tfp | 80.476 (19.524) | 95.200 (4.800) | 10.59 | bicubic | 300 | N/A |
|
181 |
+
| tf_efficientnet_b2_ap *tfp | 80.420 (19.580) | 95.040 (4.960) | 9.11 | bicubic | 260 | N/A |
|
182 |
+
| tf_efficientnet_b2_ap | 80.306 (19.694) | 95.028 (4.972) | 9.11 | bicubic | 260 | 0.890 |
|
183 |
+
| tf_efficientnet_b2 *tfp | 80.188 (19.812) | 94.974 (5.026) | 9.11 | bicubic | 260 | N/A |
|
184 |
+
| tf_efficientnet_b2 | 80.086 (19.914) | 94.908 (5.092) | 9.11 | bicubic | 260 | 0.890 |
|
185 |
+
| tf_efficientnet_lite3 | 79.812 (20.188) | 94.914 (5.086) | 8.20 | bilinear | 300 | 0.904 |
|
186 |
+
| tf_efficientnet_lite3 *tfp | 79.734 (20.266) | 94.838 (5.162) | 8.20 | bilinear | 300 | N/A |
|
187 |
+
| tf_efficientnet_b1_ap *tfp | 79.532 (20.468) | 94.378 (5.622) | 7.79 | bicubic | 240 | N/A |
|
188 |
+
| tf_efficientnet_cc_b1_8e *tfp | 79.464 (20.536)| 94.492 (5.508) | 39.7 | bicubic | 240 | 0.88 |
|
189 |
+
| tf_efficientnet_cc_b1_8e | 79.298 (20.702) | 94.364 (5.636) | 39.7 | bicubic | 240 | 0.88 |
|
190 |
+
| tf_efficientnet_b1_ap | 79.278 (20.722) | 94.308 (5.692) | 7.79 | bicubic | 240 | 0.88 |
|
191 |
+
| tf_efficientnet_b1 *tfp | 79.172 (20.828) | 94.450 (5.550) | 7.79 | bicubic | 240 | N/A |
|
192 |
+
| tf_efficientnet_em *tfp | 78.958 (21.042) | 94.458 (5.542) | 6.90 | bicubic | 240 | N/A |
|
193 |
+
| tf_efficientnet_b0_ns *tfp | 78.806 (21.194) | 94.496 (5.504) | 5.29 | bicubic | 224 | N/A |
|
194 |
+
| tf_mixnet_l *tfp | 78.846 (21.154) | 94.212 (5.788) | 7.33 | bilinear | 224 | N/A |
|
195 |
+
| tf_efficientnet_b1 | 78.826 (21.174) | 94.198 (5.802) | 7.79 | bicubic | 240 | 0.88 |
|
196 |
+
| tf_mixnet_l | 78.770 (21.230) | 94.004 (5.996) | 7.33 | bicubic | 224 | 0.875 |
|
197 |
+
| tf_efficientnet_em | 78.742 (21.258) | 94.332 (5.668) | 6.90 | bicubic | 240 | 0.875 |
|
198 |
+
| tf_efficientnet_b0_ns | 78.658 (21.342) | 94.376 (5.624) | 5.29 | bicubic | 224 | 0.875 |
|
199 |
+
| tf_efficientnet_cc_b0_8e *tfp | 78.314 (21.686) | 93.790 (6.210) | 24.0 | bicubic | 224 | 0.875 |
|
200 |
+
| tf_efficientnet_cc_b0_8e | 77.908 (22.092) | 93.656 (6.344) | 24.0 | bicubic | 224 | 0.875 |
|
201 |
+
| tf_efficientnet_cc_b0_4e *tfp | 77.746 (22.254) | 93.552 (6.448) | 13.3 | bicubic | 224 | 0.875 |
|
202 |
+
| tf_efficientnet_cc_b0_4e | 77.304 (22.696) | 93.332 (6.668) | 13.3 | bicubic | 224 | 0.875 |
|
203 |
+
| tf_efficientnet_es *tfp | 77.616 (22.384) | 93.750 (6.250) | 5.44 | bicubic | 224 | N/A |
|
204 |
+
| tf_efficientnet_lite2 *tfp | 77.544 (22.456) | 93.800 (6.200) | 6.09 | bilinear | 260 | N/A |
|
205 |
+
| tf_efficientnet_lite2 | 77.460 (22.540) | 93.746 (6.254) | 6.09 | bicubic | 260 | 0.89 |
|
206 |
+
| tf_efficientnet_b0_ap *tfp | 77.514 (22.486) | 93.576 (6.424) | 5.29 | bicubic | 224 | N/A |
|
207 |
+
| tf_efficientnet_es | 77.264 (22.736) | 93.600 (6.400) | 5.44 | bicubic | 224 | N/A |
|
208 |
+
| tf_efficientnet_b0 *tfp | 77.258 (22.742) | 93.478 (6.522) | 5.29 | bicubic | 224 | N/A |
|
209 |
+
| tf_efficientnet_b0_ap | 77.084 (22.916) | 93.254 (6.746) | 5.29 | bicubic | 224 | 0.875 |
|
210 |
+
| tf_mixnet_m *tfp | 77.072 (22.928) | 93.368 (6.632) | 5.01 | bilinear | 224 | N/A |
|
211 |
+
| tf_mixnet_m | 76.950 (23.050) | 93.156 (6.844) | 5.01 | bicubic | 224 | 0.875 |
|
212 |
+
| tf_efficientnet_b0 | 76.848 (23.152) | 93.228 (6.772) | 5.29 | bicubic | 224 | 0.875 |
|
213 |
+
| tf_efficientnet_lite1 *tfp | 76.764 (23.236) | 93.326 (6.674) | 5.42 | bilinear | 240 | N/A |
|
214 |
+
| tf_efficientnet_lite1 | 76.638 (23.362) | 93.232 (6.768) | 5.42 | bicubic | 240 | 0.882 |
|
215 |
+
| tf_mixnet_s *tfp | 75.800 (24.200) | 92.788 (7.212) | 4.13 | bilinear | 224 | N/A |
|
216 |
+
| tf_mobilenetv3_large_100 *tfp | 75.768 (24.232) | 92.710 (7.290) | 5.48 | bilinear | 224 | N/A |
|
217 |
+
| tf_mixnet_s | 75.648 (24.352) | 92.636 (7.364) | 4.13 | bicubic | 224 | 0.875 |
|
218 |
+
| tf_mobilenetv3_large_100 | 75.516 (24.484) | 92.600 (7.400) | 5.48 | bilinear | 224 | 0.875 |
|
219 |
+
| tf_efficientnet_lite0 *tfp | 75.074 (24.926) | 92.314 (7.686) | 4.65 | bilinear | 224 | N/A |
|
220 |
+
| tf_efficientnet_lite0 | 74.842 (25.158) | 92.170 (7.830) | 4.65 | bicubic | 224 | 0.875 |
|
221 |
+
| tf_mobilenetv3_large_075 *tfp | 73.730 (26.270) | 91.616 (8.384) | 3.99 | bilinear | 224 |N/A |
|
222 |
+
| tf_mobilenetv3_large_075 | 73.442 (26.558) | 91.352 (8.648) | 3.99 | bilinear | 224 | 0.875 |
|
223 |
+
| tf_mobilenetv3_large_minimal_100 *tfp | 72.678 (27.322) | 90.860 (9.140) | 3.92 | bilinear | 224 | N/A |
|
224 |
+
| tf_mobilenetv3_large_minimal_100 | 72.244 (27.756) | 90.636 (9.364) | 3.92 | bilinear | 224 | 0.875 |
|
225 |
+
| tf_mobilenetv3_small_100 *tfp | 67.918 (32.082) | 87.958 (12.042 | 2.54 | bilinear | 224 | N/A |
|
226 |
+
| tf_mobilenetv3_small_100 | 67.918 (32.082) | 87.662 (12.338) | 2.54 | bilinear | 224 | 0.875 |
|
227 |
+
| tf_mobilenetv3_small_075 *tfp | 66.142 (33.858) | 86.498 (13.502) | 2.04 | bilinear | 224 | N/A |
|
228 |
+
| tf_mobilenetv3_small_075 | 65.718 (34.282) | 86.136 (13.864) | 2.04 | bilinear | 224 | 0.875 |
|
229 |
+
| tf_mobilenetv3_small_minimal_100 *tfp | 63.378 (36.622) | 84.802 (15.198) | 2.04 | bilinear | 224 | N/A |
|
230 |
+
| tf_mobilenetv3_small_minimal_100 | 62.898 (37.102) | 84.230 (15.770) | 2.04 | bilinear | 224 | 0.875 |
|
231 |
+
|
232 |
+
|
233 |
+
*tfp models validated with `tf-preprocessing` pipeline
|
234 |
+
|
235 |
+
Google tf and tflite weights ported from official Tensorflow repositories
|
236 |
+
* https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
|
237 |
+
* https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet
|
238 |
+
* https://github.com/tensorflow/models/tree/master/research/slim/nets/mobilenet
|
239 |
+
|
240 |
+
## Usage
|
241 |
+
|
242 |
+
### Environment
|
243 |
+
|
244 |
+
All development and testing has been done in Conda Python 3 environments on Linux x86-64 systems, specifically Python 3.6.x, 3.7.x, 3.8.x.
|
245 |
+
|
246 |
+
Users have reported that a Python 3 Anaconda install in Windows works. I have not verified this myself.
|
247 |
+
|
248 |
+
PyTorch versions 1.4, 1.5, 1.6 have been tested with this code.
|
249 |
+
|
250 |
+
I've tried to keep the dependencies minimal, the setup is as per the PyTorch default install instructions for Conda:
|
251 |
+
```
|
252 |
+
conda create -n torch-env
|
253 |
+
conda activate torch-env
|
254 |
+
conda install -c pytorch pytorch torchvision cudatoolkit=10.2
|
255 |
+
```
|
256 |
+
|
257 |
+
### PyTorch Hub
|
258 |
+
|
259 |
+
Models can be accessed via the PyTorch Hub API
|
260 |
+
|
261 |
+
```
|
262 |
+
>>> torch.hub.list('rwightman/gen-efficientnet-pytorch')
|
263 |
+
['efficientnet_b0', ...]
|
264 |
+
>>> model = torch.hub.load('rwightman/gen-efficientnet-pytorch', 'efficientnet_b0', pretrained=True)
|
265 |
+
>>> model.eval()
|
266 |
+
>>> output = model(torch.randn(1,3,224,224))
|
267 |
+
```
|
268 |
+
|
269 |
+
### Pip
|
270 |
+
This package can be installed via pip.
|
271 |
+
|
272 |
+
Install (after conda env/install):
|
273 |
+
```
|
274 |
+
pip install geffnet
|
275 |
+
```
|
276 |
+
|
277 |
+
Eval use:
|
278 |
+
```
|
279 |
+
>>> import geffnet
|
280 |
+
>>> m = geffnet.create_model('mobilenetv3_large_100', pretrained=True)
|
281 |
+
>>> m.eval()
|
282 |
+
```
|
283 |
+
|
284 |
+
Train use:
|
285 |
+
```
|
286 |
+
>>> import geffnet
|
287 |
+
>>> # models can also be created by using the entrypoint directly
|
288 |
+
>>> m = geffnet.efficientnet_b2(pretrained=True, drop_rate=0.25, drop_connect_rate=0.2)
|
289 |
+
>>> m.train()
|
290 |
+
```
|
291 |
+
|
292 |
+
Create in a nn.Sequential container, for fast.ai, etc:
|
293 |
+
```
|
294 |
+
>>> import geffnet
|
295 |
+
>>> m = geffnet.mixnet_l(pretrained=True, drop_rate=0.25, drop_connect_rate=0.2, as_sequential=True)
|
296 |
+
```
|
297 |
+
|
298 |
+
### Exporting
|
299 |
+
|
300 |
+
Scripts are included to
|
301 |
+
* export models to ONNX (`onnx_export.py`)
|
302 |
+
* optimized ONNX graph (`onnx_optimize.py` or `onnx_validate.py` w/ `--onnx-output-opt` arg)
|
303 |
+
* validate with ONNX runtime (`onnx_validate.py`)
|
304 |
+
* convert ONNX model to Caffe2 (`onnx_to_caffe.py`)
|
305 |
+
* validate in Caffe2 (`caffe2_validate.py`)
|
306 |
+
* benchmark in Caffe2 w/ FLOPs, parameters output (`caffe2_benchmark.py`)
|
307 |
+
|
308 |
+
As an example, to export the MobileNet-V3 pretrained model and then run an Imagenet validation:
|
309 |
+
```
|
310 |
+
python onnx_export.py --model mobilenetv3_large_100 ./mobilenetv3_100.onnx
|
311 |
+
python onnx_validate.py /imagenet/validation/ --onnx-input ./mobilenetv3_100.onnx
|
312 |
+
```
|
313 |
+
|
314 |
+
These scripts were tested to be working as of PyTorch 1.6 and ONNX 1.7 w/ ONNX runtime 1.4. Caffe2 compatible
|
315 |
+
export now requires additional args mentioned in the export script (not needed in earlier versions).
|
316 |
+
|
317 |
+
#### Export Notes
|
318 |
+
1. The TF ported weights with the 'SAME' conv padding activated cannot be exported to ONNX unless `_EXPORTABLE` flag in `config.py` is set to True. Use `config.set_exportable(True)` as in the `onnx_export.py` script.
|
319 |
+
2. TF ported models with 'SAME' padding will have the padding fixed at export time to the resolution used for export. Even though dynamic padding is supported in opset >= 11, I can't get it working.
|
320 |
+
3. ONNX optimize facility doesn't work reliably in PyTorch 1.6 / ONNX 1.7. Fortunately, the onnxruntime based inference is working very well now and includes on the fly optimization.
|
321 |
+
3. ONNX / Caffe2 export/import frequently breaks with different PyTorch and ONNX version releases. Please check their respective issue trackers before filing issues here.
|
322 |
+
|
323 |
+
|
annotator/normalbae/models/submodules/efficientnet_repo/caffe2_benchmark.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Caffe2 validation script
|
2 |
+
|
3 |
+
This script runs Caffe2 benchmark on exported ONNX model.
|
4 |
+
It is a useful tool for reporting model FLOPS.
|
5 |
+
|
6 |
+
Copyright 2020 Ross Wightman
|
7 |
+
"""
|
8 |
+
import argparse
|
9 |
+
from caffe2.python import core, workspace, model_helper
|
10 |
+
from caffe2.proto import caffe2_pb2
|
11 |
+
|
12 |
+
|
13 |
+
parser = argparse.ArgumentParser(description='Caffe2 Model Benchmark')
|
14 |
+
parser.add_argument('--c2-prefix', default='', type=str, metavar='NAME',
|
15 |
+
help='caffe2 model pb name prefix')
|
16 |
+
parser.add_argument('--c2-init', default='', type=str, metavar='PATH',
|
17 |
+
help='caffe2 model init .pb')
|
18 |
+
parser.add_argument('--c2-predict', default='', type=str, metavar='PATH',
|
19 |
+
help='caffe2 model predict .pb')
|
20 |
+
parser.add_argument('-b', '--batch-size', default=1, type=int,
|
21 |
+
metavar='N', help='mini-batch size (default: 1)')
|
22 |
+
parser.add_argument('--img-size', default=224, type=int,
|
23 |
+
metavar='N', help='Input image dimension, uses model default if empty')
|
24 |
+
|
25 |
+
|
26 |
+
def main():
|
27 |
+
args = parser.parse_args()
|
28 |
+
args.gpu_id = 0
|
29 |
+
if args.c2_prefix:
|
30 |
+
args.c2_init = args.c2_prefix + '.init.pb'
|
31 |
+
args.c2_predict = args.c2_prefix + '.predict.pb'
|
32 |
+
|
33 |
+
model = model_helper.ModelHelper(name="le_net", init_params=False)
|
34 |
+
|
35 |
+
# Bring in the init net from init_net.pb
|
36 |
+
init_net_proto = caffe2_pb2.NetDef()
|
37 |
+
with open(args.c2_init, "rb") as f:
|
38 |
+
init_net_proto.ParseFromString(f.read())
|
39 |
+
model.param_init_net = core.Net(init_net_proto)
|
40 |
+
|
41 |
+
# bring in the predict net from predict_net.pb
|
42 |
+
predict_net_proto = caffe2_pb2.NetDef()
|
43 |
+
with open(args.c2_predict, "rb") as f:
|
44 |
+
predict_net_proto.ParseFromString(f.read())
|
45 |
+
model.net = core.Net(predict_net_proto)
|
46 |
+
|
47 |
+
# CUDA performance not impressive
|
48 |
+
#device_opts = core.DeviceOption(caffe2_pb2.PROTO_CUDA, args.gpu_id)
|
49 |
+
#model.net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True)
|
50 |
+
#model.param_init_net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True)
|
51 |
+
|
52 |
+
input_blob = model.net.external_inputs[0]
|
53 |
+
model.param_init_net.GaussianFill(
|
54 |
+
[],
|
55 |
+
input_blob.GetUnscopedName(),
|
56 |
+
shape=(args.batch_size, 3, args.img_size, args.img_size),
|
57 |
+
mean=0.0,
|
58 |
+
std=1.0)
|
59 |
+
workspace.RunNetOnce(model.param_init_net)
|
60 |
+
workspace.CreateNet(model.net, overwrite=True)
|
61 |
+
workspace.BenchmarkNet(model.net.Proto().name, 5, 20, True)
|
62 |
+
|
63 |
+
|
64 |
+
if __name__ == '__main__':
|
65 |
+
main()
|
annotator/normalbae/models/submodules/efficientnet_repo/caffe2_validate.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Caffe2 validation script
|
2 |
+
|
3 |
+
This script is created to verify exported ONNX models running in Caffe2
|
4 |
+
It utilizes the same PyTorch dataloader/processing pipeline for a
|
5 |
+
fair comparison against the originals.
|
6 |
+
|
7 |
+
Copyright 2020 Ross Wightman
|
8 |
+
"""
|
9 |
+
import argparse
|
10 |
+
import numpy as np
|
11 |
+
from caffe2.python import core, workspace, model_helper
|
12 |
+
from caffe2.proto import caffe2_pb2
|
13 |
+
from data import create_loader, resolve_data_config, Dataset
|
14 |
+
from utils import AverageMeter
|
15 |
+
import time
|
16 |
+
|
17 |
+
parser = argparse.ArgumentParser(description='Caffe2 ImageNet Validation')
|
18 |
+
parser.add_argument('data', metavar='DIR',
|
19 |
+
help='path to dataset')
|
20 |
+
parser.add_argument('--c2-prefix', default='', type=str, metavar='NAME',
|
21 |
+
help='caffe2 model pb name prefix')
|
22 |
+
parser.add_argument('--c2-init', default='', type=str, metavar='PATH',
|
23 |
+
help='caffe2 model init .pb')
|
24 |
+
parser.add_argument('--c2-predict', default='', type=str, metavar='PATH',
|
25 |
+
help='caffe2 model predict .pb')
|
26 |
+
parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
|
27 |
+
help='number of data loading workers (default: 2)')
|
28 |
+
parser.add_argument('-b', '--batch-size', default=256, type=int,
|
29 |
+
metavar='N', help='mini-batch size (default: 256)')
|
30 |
+
parser.add_argument('--img-size', default=None, type=int,
|
31 |
+
metavar='N', help='Input image dimension, uses model default if empty')
|
32 |
+
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
|
33 |
+
help='Override mean pixel value of dataset')
|
34 |
+
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
|
35 |
+
help='Override std deviation of of dataset')
|
36 |
+
parser.add_argument('--crop-pct', type=float, default=None, metavar='PCT',
|
37 |
+
help='Override default crop pct of 0.875')
|
38 |
+
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
|
39 |
+
help='Image resize interpolation type (overrides model)')
|
40 |
+
parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true',
|
41 |
+
help='use tensorflow mnasnet preporcessing')
|
42 |
+
parser.add_argument('--print-freq', '-p', default=10, type=int,
|
43 |
+
metavar='N', help='print frequency (default: 10)')
|
44 |
+
|
45 |
+
|
46 |
+
def main():
|
47 |
+
args = parser.parse_args()
|
48 |
+
args.gpu_id = 0
|
49 |
+
if args.c2_prefix:
|
50 |
+
args.c2_init = args.c2_prefix + '.init.pb'
|
51 |
+
args.c2_predict = args.c2_prefix + '.predict.pb'
|
52 |
+
|
53 |
+
model = model_helper.ModelHelper(name="validation_net", init_params=False)
|
54 |
+
|
55 |
+
# Bring in the init net from init_net.pb
|
56 |
+
init_net_proto = caffe2_pb2.NetDef()
|
57 |
+
with open(args.c2_init, "rb") as f:
|
58 |
+
init_net_proto.ParseFromString(f.read())
|
59 |
+
model.param_init_net = core.Net(init_net_proto)
|
60 |
+
|
61 |
+
# bring in the predict net from predict_net.pb
|
62 |
+
predict_net_proto = caffe2_pb2.NetDef()
|
63 |
+
with open(args.c2_predict, "rb") as f:
|
64 |
+
predict_net_proto.ParseFromString(f.read())
|
65 |
+
model.net = core.Net(predict_net_proto)
|
66 |
+
|
67 |
+
data_config = resolve_data_config(None, args)
|
68 |
+
loader = create_loader(
|
69 |
+
Dataset(args.data, load_bytes=args.tf_preprocessing),
|
70 |
+
input_size=data_config['input_size'],
|
71 |
+
batch_size=args.batch_size,
|
72 |
+
use_prefetcher=False,
|
73 |
+
interpolation=data_config['interpolation'],
|
74 |
+
mean=data_config['mean'],
|
75 |
+
std=data_config['std'],
|
76 |
+
num_workers=args.workers,
|
77 |
+
crop_pct=data_config['crop_pct'],
|
78 |
+
tensorflow_preprocessing=args.tf_preprocessing)
|
79 |
+
|
80 |
+
# this is so obvious, wonderful interface </sarcasm>
|
81 |
+
input_blob = model.net.external_inputs[0]
|
82 |
+
output_blob = model.net.external_outputs[0]
|
83 |
+
|
84 |
+
if True:
|
85 |
+
device_opts = None
|
86 |
+
else:
|
87 |
+
# CUDA is crashing, no idea why, awesome error message, give it a try for kicks
|
88 |
+
device_opts = core.DeviceOption(caffe2_pb2.PROTO_CUDA, args.gpu_id)
|
89 |
+
model.net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True)
|
90 |
+
model.param_init_net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True)
|
91 |
+
|
92 |
+
model.param_init_net.GaussianFill(
|
93 |
+
[], input_blob.GetUnscopedName(),
|
94 |
+
shape=(1,) + data_config['input_size'], mean=0.0, std=1.0)
|
95 |
+
workspace.RunNetOnce(model.param_init_net)
|
96 |
+
workspace.CreateNet(model.net, overwrite=True)
|
97 |
+
|
98 |
+
batch_time = AverageMeter()
|
99 |
+
top1 = AverageMeter()
|
100 |
+
top5 = AverageMeter()
|
101 |
+
end = time.time()
|
102 |
+
for i, (input, target) in enumerate(loader):
|
103 |
+
# run the net and return prediction
|
104 |
+
caffe2_in = input.data.numpy()
|
105 |
+
workspace.FeedBlob(input_blob, caffe2_in, device_opts)
|
106 |
+
workspace.RunNet(model.net, num_iter=1)
|
107 |
+
output = workspace.FetchBlob(output_blob)
|
108 |
+
|
109 |
+
# measure accuracy and record loss
|
110 |
+
prec1, prec5 = accuracy_np(output.data, target.numpy())
|
111 |
+
top1.update(prec1.item(), input.size(0))
|
112 |
+
top5.update(prec5.item(), input.size(0))
|
113 |
+
|
114 |
+
# measure elapsed time
|
115 |
+
batch_time.update(time.time() - end)
|
116 |
+
end = time.time()
|
117 |
+
|
118 |
+
if i % args.print_freq == 0:
|
119 |
+
print('Test: [{0}/{1}]\t'
|
120 |
+
'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {rate_avg:.3f}/s, {ms_avg:.3f} ms/sample) \t'
|
121 |
+
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
|
122 |
+
'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
|
123 |
+
i, len(loader), batch_time=batch_time, rate_avg=input.size(0) / batch_time.avg,
|
124 |
+
ms_avg=100 * batch_time.avg / input.size(0), top1=top1, top5=top5))
|
125 |
+
|
126 |
+
print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format(
|
127 |
+
top1=top1, top1a=100-top1.avg, top5=top5, top5a=100.-top5.avg))
|
128 |
+
|
129 |
+
|
130 |
+
def accuracy_np(output, target):
|
131 |
+
max_indices = np.argsort(output, axis=1)[:, ::-1]
|
132 |
+
top5 = 100 * np.equal(max_indices[:, :5], target[:, np.newaxis]).sum(axis=1).mean()
|
133 |
+
top1 = 100 * np.equal(max_indices[:, 0], target).mean()
|
134 |
+
return top1, top5
|
135 |
+
|
136 |
+
|
137 |
+
if __name__ == '__main__':
|
138 |
+
main()
|
annotator/normalbae/models/submodules/efficientnet_repo/data/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .dataset import Dataset
|
2 |
+
from .transforms import *
|
3 |
+
from .loader import create_loader
|
annotator/normalbae/models/submodules/efficientnet_repo/data/dataset.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Quick n simple image folder dataset
|
2 |
+
|
3 |
+
Copyright 2020 Ross Wightman
|
4 |
+
"""
|
5 |
+
import torch.utils.data as data
|
6 |
+
|
7 |
+
import os
|
8 |
+
import re
|
9 |
+
import torch
|
10 |
+
from PIL import Image
|
11 |
+
|
12 |
+
|
13 |
+
IMG_EXTENSIONS = ['.png', '.jpg', '.jpeg']
|
14 |
+
|
15 |
+
|
16 |
+
def natural_key(string_):
|
17 |
+
"""See http://www.codinghorror.com/blog/archives/001018.html"""
|
18 |
+
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
|
19 |
+
|
20 |
+
|
21 |
+
def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True):
|
22 |
+
if class_to_idx is None:
|
23 |
+
class_to_idx = dict()
|
24 |
+
build_class_idx = True
|
25 |
+
else:
|
26 |
+
build_class_idx = False
|
27 |
+
labels = []
|
28 |
+
filenames = []
|
29 |
+
for root, subdirs, files in os.walk(folder, topdown=False):
|
30 |
+
rel_path = os.path.relpath(root, folder) if (root != folder) else ''
|
31 |
+
label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_')
|
32 |
+
if build_class_idx and not subdirs:
|
33 |
+
class_to_idx[label] = None
|
34 |
+
for f in files:
|
35 |
+
base, ext = os.path.splitext(f)
|
36 |
+
if ext.lower() in types:
|
37 |
+
filenames.append(os.path.join(root, f))
|
38 |
+
labels.append(label)
|
39 |
+
if build_class_idx:
|
40 |
+
classes = sorted(class_to_idx.keys(), key=natural_key)
|
41 |
+
for idx, c in enumerate(classes):
|
42 |
+
class_to_idx[c] = idx
|
43 |
+
images_and_targets = zip(filenames, [class_to_idx[l] for l in labels])
|
44 |
+
if sort:
|
45 |
+
images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0]))
|
46 |
+
if build_class_idx:
|
47 |
+
return images_and_targets, classes, class_to_idx
|
48 |
+
else:
|
49 |
+
return images_and_targets
|
50 |
+
|
51 |
+
|
52 |
+
class Dataset(data.Dataset):
|
53 |
+
|
54 |
+
def __init__(
|
55 |
+
self,
|
56 |
+
root,
|
57 |
+
transform=None,
|
58 |
+
load_bytes=False):
|
59 |
+
|
60 |
+
imgs, _, _ = find_images_and_targets(root)
|
61 |
+
if len(imgs) == 0:
|
62 |
+
raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"
|
63 |
+
"Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
|
64 |
+
self.root = root
|
65 |
+
self.imgs = imgs
|
66 |
+
self.transform = transform
|
67 |
+
self.load_bytes = load_bytes
|
68 |
+
|
69 |
+
def __getitem__(self, index):
|
70 |
+
path, target = self.imgs[index]
|
71 |
+
img = open(path, 'rb').read() if self.load_bytes else Image.open(path).convert('RGB')
|
72 |
+
if self.transform is not None:
|
73 |
+
img = self.transform(img)
|
74 |
+
if target is None:
|
75 |
+
target = torch.zeros(1).long()
|
76 |
+
return img, target
|
77 |
+
|
78 |
+
def __len__(self):
|
79 |
+
return len(self.imgs)
|
80 |
+
|
81 |
+
def filenames(self, indices=[], basename=False):
|
82 |
+
if indices:
|
83 |
+
if basename:
|
84 |
+
return [os.path.basename(self.imgs[i][0]) for i in indices]
|
85 |
+
else:
|
86 |
+
return [self.imgs[i][0] for i in indices]
|
87 |
+
else:
|
88 |
+
if basename:
|
89 |
+
return [os.path.basename(x[0]) for x in self.imgs]
|
90 |
+
else:
|
91 |
+
return [x[0] for x in self.imgs]
|
annotator/normalbae/models/submodules/efficientnet_repo/data/loader.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Fast Collate, CUDA Prefetcher
|
2 |
+
|
3 |
+
Prefetcher and Fast Collate inspired by NVIDIA APEX example at
|
4 |
+
https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#diff-cf86c282ff7fba81fad27a559379d5bf
|
5 |
+
|
6 |
+
Hacked together by / Copyright 2020 Ross Wightman
|
7 |
+
"""
|
8 |
+
import torch
|
9 |
+
import torch.utils.data
|
10 |
+
from .transforms import *
|
11 |
+
|
12 |
+
|
13 |
+
def fast_collate(batch):
|
14 |
+
targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
|
15 |
+
batch_size = len(targets)
|
16 |
+
tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
|
17 |
+
for i in range(batch_size):
|
18 |
+
tensor[i] += torch.from_numpy(batch[i][0])
|
19 |
+
|
20 |
+
return tensor, targets
|
21 |
+
|
22 |
+
|
23 |
+
class PrefetchLoader:
|
24 |
+
|
25 |
+
def __init__(self,
|
26 |
+
loader,
|
27 |
+
mean=IMAGENET_DEFAULT_MEAN,
|
28 |
+
std=IMAGENET_DEFAULT_STD):
|
29 |
+
self.loader = loader
|
30 |
+
self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
|
31 |
+
self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
|
32 |
+
|
33 |
+
def __iter__(self):
|
34 |
+
stream = torch.cuda.Stream()
|
35 |
+
first = True
|
36 |
+
|
37 |
+
for next_input, next_target in self.loader:
|
38 |
+
with torch.cuda.stream(stream):
|
39 |
+
next_input = next_input.cuda(non_blocking=True)
|
40 |
+
next_target = next_target.cuda(non_blocking=True)
|
41 |
+
next_input = next_input.float().sub_(self.mean).div_(self.std)
|
42 |
+
|
43 |
+
if not first:
|
44 |
+
yield input, target
|
45 |
+
else:
|
46 |
+
first = False
|
47 |
+
|
48 |
+
torch.cuda.current_stream().wait_stream(stream)
|
49 |
+
input = next_input
|
50 |
+
target = next_target
|
51 |
+
|
52 |
+
yield input, target
|
53 |
+
|
54 |
+
def __len__(self):
|
55 |
+
return len(self.loader)
|
56 |
+
|
57 |
+
@property
|
58 |
+
def sampler(self):
|
59 |
+
return self.loader.sampler
|
60 |
+
|
61 |
+
|
62 |
+
def create_loader(
|
63 |
+
dataset,
|
64 |
+
input_size,
|
65 |
+
batch_size,
|
66 |
+
is_training=False,
|
67 |
+
use_prefetcher=True,
|
68 |
+
interpolation='bilinear',
|
69 |
+
mean=IMAGENET_DEFAULT_MEAN,
|
70 |
+
std=IMAGENET_DEFAULT_STD,
|
71 |
+
num_workers=1,
|
72 |
+
crop_pct=None,
|
73 |
+
tensorflow_preprocessing=False
|
74 |
+
):
|
75 |
+
if isinstance(input_size, tuple):
|
76 |
+
img_size = input_size[-2:]
|
77 |
+
else:
|
78 |
+
img_size = input_size
|
79 |
+
|
80 |
+
if tensorflow_preprocessing and use_prefetcher:
|
81 |
+
from data.tf_preprocessing import TfPreprocessTransform
|
82 |
+
transform = TfPreprocessTransform(
|
83 |
+
is_training=is_training, size=img_size, interpolation=interpolation)
|
84 |
+
else:
|
85 |
+
transform = transforms_imagenet_eval(
|
86 |
+
img_size,
|
87 |
+
interpolation=interpolation,
|
88 |
+
use_prefetcher=use_prefetcher,
|
89 |
+
mean=mean,
|
90 |
+
std=std,
|
91 |
+
crop_pct=crop_pct)
|
92 |
+
|
93 |
+
dataset.transform = transform
|
94 |
+
|
95 |
+
loader = torch.utils.data.DataLoader(
|
96 |
+
dataset,
|
97 |
+
batch_size=batch_size,
|
98 |
+
shuffle=False,
|
99 |
+
num_workers=num_workers,
|
100 |
+
collate_fn=fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate,
|
101 |
+
)
|
102 |
+
if use_prefetcher:
|
103 |
+
loader = PrefetchLoader(
|
104 |
+
loader,
|
105 |
+
mean=mean,
|
106 |
+
std=std)
|
107 |
+
|
108 |
+
return loader
|
annotator/normalbae/models/submodules/efficientnet_repo/data/tf_preprocessing.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Tensorflow Preprocessing Adapter
|
2 |
+
|
3 |
+
Allows use of Tensorflow preprocessing pipeline in PyTorch Transform
|
4 |
+
|
5 |
+
Copyright of original Tensorflow code below.
|
6 |
+
|
7 |
+
Hacked together by / Copyright 2020 Ross Wightman
|
8 |
+
"""
|
9 |
+
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
10 |
+
#
|
11 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
12 |
+
# you may not use this file except in compliance with the License.
|
13 |
+
# You may obtain a copy of the License at
|
14 |
+
#
|
15 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
16 |
+
#
|
17 |
+
# Unless required by applicable law or agreed to in writing, software
|
18 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
19 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
20 |
+
# See the License for the specific language governing permissions and
|
21 |
+
# limitations under the License.
|
22 |
+
# ==============================================================================
|
23 |
+
from __future__ import absolute_import
|
24 |
+
from __future__ import division
|
25 |
+
from __future__ import print_function
|
26 |
+
|
27 |
+
import tensorflow as tf
|
28 |
+
import numpy as np
|
29 |
+
|
30 |
+
IMAGE_SIZE = 224
|
31 |
+
CROP_PADDING = 32
|
32 |
+
|
33 |
+
|
34 |
+
def distorted_bounding_box_crop(image_bytes,
|
35 |
+
bbox,
|
36 |
+
min_object_covered=0.1,
|
37 |
+
aspect_ratio_range=(0.75, 1.33),
|
38 |
+
area_range=(0.05, 1.0),
|
39 |
+
max_attempts=100,
|
40 |
+
scope=None):
|
41 |
+
"""Generates cropped_image using one of the bboxes randomly distorted.
|
42 |
+
|
43 |
+
See `tf.image.sample_distorted_bounding_box` for more documentation.
|
44 |
+
|
45 |
+
Args:
|
46 |
+
image_bytes: `Tensor` of binary image data.
|
47 |
+
bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]`
|
48 |
+
where each coordinate is [0, 1) and the coordinates are arranged
|
49 |
+
as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole
|
50 |
+
image.
|
51 |
+
min_object_covered: An optional `float`. Defaults to `0.1`. The cropped
|
52 |
+
area of the image must contain at least this fraction of any bounding
|
53 |
+
box supplied.
|
54 |
+
aspect_ratio_range: An optional list of `float`s. The cropped area of the
|
55 |
+
image must have an aspect ratio = width / height within this range.
|
56 |
+
area_range: An optional list of `float`s. The cropped area of the image
|
57 |
+
must contain a fraction of the supplied image within in this range.
|
58 |
+
max_attempts: An optional `int`. Number of attempts at generating a cropped
|
59 |
+
region of the image of the specified constraints. After `max_attempts`
|
60 |
+
failures, return the entire image.
|
61 |
+
scope: Optional `str` for name scope.
|
62 |
+
Returns:
|
63 |
+
cropped image `Tensor`
|
64 |
+
"""
|
65 |
+
with tf.name_scope(scope, 'distorted_bounding_box_crop', [image_bytes, bbox]):
|
66 |
+
shape = tf.image.extract_jpeg_shape(image_bytes)
|
67 |
+
sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
|
68 |
+
shape,
|
69 |
+
bounding_boxes=bbox,
|
70 |
+
min_object_covered=min_object_covered,
|
71 |
+
aspect_ratio_range=aspect_ratio_range,
|
72 |
+
area_range=area_range,
|
73 |
+
max_attempts=max_attempts,
|
74 |
+
use_image_if_no_bounding_boxes=True)
|
75 |
+
bbox_begin, bbox_size, _ = sample_distorted_bounding_box
|
76 |
+
|
77 |
+
# Crop the image to the specified bounding box.
|
78 |
+
offset_y, offset_x, _ = tf.unstack(bbox_begin)
|
79 |
+
target_height, target_width, _ = tf.unstack(bbox_size)
|
80 |
+
crop_window = tf.stack([offset_y, offset_x, target_height, target_width])
|
81 |
+
image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
|
82 |
+
|
83 |
+
return image
|
84 |
+
|
85 |
+
|
86 |
+
def _at_least_x_are_equal(a, b, x):
|
87 |
+
"""At least `x` of `a` and `b` `Tensors` are equal."""
|
88 |
+
match = tf.equal(a, b)
|
89 |
+
match = tf.cast(match, tf.int32)
|
90 |
+
return tf.greater_equal(tf.reduce_sum(match), x)
|
91 |
+
|
92 |
+
|
93 |
+
def _decode_and_random_crop(image_bytes, image_size, resize_method):
|
94 |
+
"""Make a random crop of image_size."""
|
95 |
+
bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
|
96 |
+
image = distorted_bounding_box_crop(
|
97 |
+
image_bytes,
|
98 |
+
bbox,
|
99 |
+
min_object_covered=0.1,
|
100 |
+
aspect_ratio_range=(3. / 4, 4. / 3.),
|
101 |
+
area_range=(0.08, 1.0),
|
102 |
+
max_attempts=10,
|
103 |
+
scope=None)
|
104 |
+
original_shape = tf.image.extract_jpeg_shape(image_bytes)
|
105 |
+
bad = _at_least_x_are_equal(original_shape, tf.shape(image), 3)
|
106 |
+
|
107 |
+
image = tf.cond(
|
108 |
+
bad,
|
109 |
+
lambda: _decode_and_center_crop(image_bytes, image_size),
|
110 |
+
lambda: tf.image.resize([image], [image_size, image_size], resize_method)[0])
|
111 |
+
|
112 |
+
return image
|
113 |
+
|
114 |
+
|
115 |
+
def _decode_and_center_crop(image_bytes, image_size, resize_method):
|
116 |
+
"""Crops to center of image with padding then scales image_size."""
|
117 |
+
shape = tf.image.extract_jpeg_shape(image_bytes)
|
118 |
+
image_height = shape[0]
|
119 |
+
image_width = shape[1]
|
120 |
+
|
121 |
+
padded_center_crop_size = tf.cast(
|
122 |
+
((image_size / (image_size + CROP_PADDING)) *
|
123 |
+
tf.cast(tf.minimum(image_height, image_width), tf.float32)),
|
124 |
+
tf.int32)
|
125 |
+
|
126 |
+
offset_height = ((image_height - padded_center_crop_size) + 1) // 2
|
127 |
+
offset_width = ((image_width - padded_center_crop_size) + 1) // 2
|
128 |
+
crop_window = tf.stack([offset_height, offset_width,
|
129 |
+
padded_center_crop_size, padded_center_crop_size])
|
130 |
+
image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
|
131 |
+
image = tf.image.resize([image], [image_size, image_size], resize_method)[0]
|
132 |
+
|
133 |
+
return image
|
134 |
+
|
135 |
+
|
136 |
+
def _flip(image):
|
137 |
+
"""Random horizontal image flip."""
|
138 |
+
image = tf.image.random_flip_left_right(image)
|
139 |
+
return image
|
140 |
+
|
141 |
+
|
142 |
+
def preprocess_for_train(image_bytes, use_bfloat16, image_size=IMAGE_SIZE, interpolation='bicubic'):
|
143 |
+
"""Preprocesses the given image for evaluation.
|
144 |
+
|
145 |
+
Args:
|
146 |
+
image_bytes: `Tensor` representing an image binary of arbitrary size.
|
147 |
+
use_bfloat16: `bool` for whether to use bfloat16.
|
148 |
+
image_size: image size.
|
149 |
+
interpolation: image interpolation method
|
150 |
+
|
151 |
+
Returns:
|
152 |
+
A preprocessed image `Tensor`.
|
153 |
+
"""
|
154 |
+
resize_method = tf.image.ResizeMethod.BICUBIC if interpolation == 'bicubic' else tf.image.ResizeMethod.BILINEAR
|
155 |
+
image = _decode_and_random_crop(image_bytes, image_size, resize_method)
|
156 |
+
image = _flip(image)
|
157 |
+
image = tf.reshape(image, [image_size, image_size, 3])
|
158 |
+
image = tf.image.convert_image_dtype(
|
159 |
+
image, dtype=tf.bfloat16 if use_bfloat16 else tf.float32)
|
160 |
+
return image
|
161 |
+
|
162 |
+
|
163 |
+
def preprocess_for_eval(image_bytes, use_bfloat16, image_size=IMAGE_SIZE, interpolation='bicubic'):
|
164 |
+
"""Preprocesses the given image for evaluation.
|
165 |
+
|
166 |
+
Args:
|
167 |
+
image_bytes: `Tensor` representing an image binary of arbitrary size.
|
168 |
+
use_bfloat16: `bool` for whether to use bfloat16.
|
169 |
+
image_size: image size.
|
170 |
+
interpolation: image interpolation method
|
171 |
+
|
172 |
+
Returns:
|
173 |
+
A preprocessed image `Tensor`.
|
174 |
+
"""
|
175 |
+
resize_method = tf.image.ResizeMethod.BICUBIC if interpolation == 'bicubic' else tf.image.ResizeMethod.BILINEAR
|
176 |
+
image = _decode_and_center_crop(image_bytes, image_size, resize_method)
|
177 |
+
image = tf.reshape(image, [image_size, image_size, 3])
|
178 |
+
image = tf.image.convert_image_dtype(
|
179 |
+
image, dtype=tf.bfloat16 if use_bfloat16 else tf.float32)
|
180 |
+
return image
|
181 |
+
|
182 |
+
|
183 |
+
def preprocess_image(image_bytes,
|
184 |
+
is_training=False,
|
185 |
+
use_bfloat16=False,
|
186 |
+
image_size=IMAGE_SIZE,
|
187 |
+
interpolation='bicubic'):
|
188 |
+
"""Preprocesses the given image.
|
189 |
+
|
190 |
+
Args:
|
191 |
+
image_bytes: `Tensor` representing an image binary of arbitrary size.
|
192 |
+
is_training: `bool` for whether the preprocessing is for training.
|
193 |
+
use_bfloat16: `bool` for whether to use bfloat16.
|
194 |
+
image_size: image size.
|
195 |
+
interpolation: image interpolation method
|
196 |
+
|
197 |
+
Returns:
|
198 |
+
A preprocessed image `Tensor` with value range of [0, 255].
|
199 |
+
"""
|
200 |
+
if is_training:
|
201 |
+
return preprocess_for_train(image_bytes, use_bfloat16, image_size, interpolation)
|
202 |
+
else:
|
203 |
+
return preprocess_for_eval(image_bytes, use_bfloat16, image_size, interpolation)
|
204 |
+
|
205 |
+
|
206 |
+
class TfPreprocessTransform:
|
207 |
+
|
208 |
+
def __init__(self, is_training=False, size=224, interpolation='bicubic'):
|
209 |
+
self.is_training = is_training
|
210 |
+
self.size = size[0] if isinstance(size, tuple) else size
|
211 |
+
self.interpolation = interpolation
|
212 |
+
self._image_bytes = None
|
213 |
+
self.process_image = self._build_tf_graph()
|
214 |
+
self.sess = None
|
215 |
+
|
216 |
+
def _build_tf_graph(self):
|
217 |
+
with tf.device('/cpu:0'):
|
218 |
+
self._image_bytes = tf.placeholder(
|
219 |
+
shape=[],
|
220 |
+
dtype=tf.string,
|
221 |
+
)
|
222 |
+
img = preprocess_image(
|
223 |
+
self._image_bytes, self.is_training, False, self.size, self.interpolation)
|
224 |
+
return img
|
225 |
+
|
226 |
+
def __call__(self, image_bytes):
|
227 |
+
if self.sess is None:
|
228 |
+
self.sess = tf.Session()
|
229 |
+
img = self.sess.run(self.process_image, feed_dict={self._image_bytes: image_bytes})
|
230 |
+
img = img.round().clip(0, 255).astype(np.uint8)
|
231 |
+
if img.ndim < 3:
|
232 |
+
img = np.expand_dims(img, axis=-1)
|
233 |
+
img = np.rollaxis(img, 2) # HWC to CHW
|
234 |
+
return img
|
annotator/normalbae/models/submodules/efficientnet_repo/data/transforms.py
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torchvision import transforms
|
3 |
+
from PIL import Image
|
4 |
+
import math
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
DEFAULT_CROP_PCT = 0.875
|
8 |
+
|
9 |
+
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
|
10 |
+
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
|
11 |
+
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
|
12 |
+
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
|
13 |
+
IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
|
14 |
+
IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)
|
15 |
+
|
16 |
+
|
17 |
+
def resolve_data_config(model, args, default_cfg={}, verbose=True):
|
18 |
+
new_config = {}
|
19 |
+
default_cfg = default_cfg
|
20 |
+
if not default_cfg and model is not None and hasattr(model, 'default_cfg'):
|
21 |
+
default_cfg = model.default_cfg
|
22 |
+
|
23 |
+
# Resolve input/image size
|
24 |
+
# FIXME grayscale/chans arg to use different # channels?
|
25 |
+
in_chans = 3
|
26 |
+
input_size = (in_chans, 224, 224)
|
27 |
+
if args.img_size is not None:
|
28 |
+
# FIXME support passing img_size as tuple, non-square
|
29 |
+
assert isinstance(args.img_size, int)
|
30 |
+
input_size = (in_chans, args.img_size, args.img_size)
|
31 |
+
elif 'input_size' in default_cfg:
|
32 |
+
input_size = default_cfg['input_size']
|
33 |
+
new_config['input_size'] = input_size
|
34 |
+
|
35 |
+
# resolve interpolation method
|
36 |
+
new_config['interpolation'] = 'bicubic'
|
37 |
+
if args.interpolation:
|
38 |
+
new_config['interpolation'] = args.interpolation
|
39 |
+
elif 'interpolation' in default_cfg:
|
40 |
+
new_config['interpolation'] = default_cfg['interpolation']
|
41 |
+
|
42 |
+
# resolve dataset + model mean for normalization
|
43 |
+
new_config['mean'] = IMAGENET_DEFAULT_MEAN
|
44 |
+
if args.mean is not None:
|
45 |
+
mean = tuple(args.mean)
|
46 |
+
if len(mean) == 1:
|
47 |
+
mean = tuple(list(mean) * in_chans)
|
48 |
+
else:
|
49 |
+
assert len(mean) == in_chans
|
50 |
+
new_config['mean'] = mean
|
51 |
+
elif 'mean' in default_cfg:
|
52 |
+
new_config['mean'] = default_cfg['mean']
|
53 |
+
|
54 |
+
# resolve dataset + model std deviation for normalization
|
55 |
+
new_config['std'] = IMAGENET_DEFAULT_STD
|
56 |
+
if args.std is not None:
|
57 |
+
std = tuple(args.std)
|
58 |
+
if len(std) == 1:
|
59 |
+
std = tuple(list(std) * in_chans)
|
60 |
+
else:
|
61 |
+
assert len(std) == in_chans
|
62 |
+
new_config['std'] = std
|
63 |
+
elif 'std' in default_cfg:
|
64 |
+
new_config['std'] = default_cfg['std']
|
65 |
+
|
66 |
+
# resolve default crop percentage
|
67 |
+
new_config['crop_pct'] = DEFAULT_CROP_PCT
|
68 |
+
if args.crop_pct is not None:
|
69 |
+
new_config['crop_pct'] = args.crop_pct
|
70 |
+
elif 'crop_pct' in default_cfg:
|
71 |
+
new_config['crop_pct'] = default_cfg['crop_pct']
|
72 |
+
|
73 |
+
if verbose:
|
74 |
+
print('Data processing configuration for current model + dataset:')
|
75 |
+
for n, v in new_config.items():
|
76 |
+
print('\t%s: %s' % (n, str(v)))
|
77 |
+
|
78 |
+
return new_config
|
79 |
+
|
80 |
+
|
81 |
+
class ToNumpy:
|
82 |
+
|
83 |
+
def __call__(self, pil_img):
|
84 |
+
np_img = np.array(pil_img, dtype=np.uint8)
|
85 |
+
if np_img.ndim < 3:
|
86 |
+
np_img = np.expand_dims(np_img, axis=-1)
|
87 |
+
np_img = np.rollaxis(np_img, 2) # HWC to CHW
|
88 |
+
return np_img
|
89 |
+
|
90 |
+
|
91 |
+
class ToTensor:
|
92 |
+
|
93 |
+
def __init__(self, dtype=torch.float32):
|
94 |
+
self.dtype = dtype
|
95 |
+
|
96 |
+
def __call__(self, pil_img):
|
97 |
+
np_img = np.array(pil_img, dtype=np.uint8)
|
98 |
+
if np_img.ndim < 3:
|
99 |
+
np_img = np.expand_dims(np_img, axis=-1)
|
100 |
+
np_img = np.rollaxis(np_img, 2) # HWC to CHW
|
101 |
+
return torch.from_numpy(np_img).to(dtype=self.dtype)
|
102 |
+
|
103 |
+
|
104 |
+
def _pil_interp(method):
|
105 |
+
if method == 'bicubic':
|
106 |
+
return Image.BICUBIC
|
107 |
+
elif method == 'lanczos':
|
108 |
+
return Image.LANCZOS
|
109 |
+
elif method == 'hamming':
|
110 |
+
return Image.HAMMING
|
111 |
+
else:
|
112 |
+
# default bilinear, do we want to allow nearest?
|
113 |
+
return Image.BILINEAR
|
114 |
+
|
115 |
+
|
116 |
+
def transforms_imagenet_eval(
|
117 |
+
img_size=224,
|
118 |
+
crop_pct=None,
|
119 |
+
interpolation='bilinear',
|
120 |
+
use_prefetcher=False,
|
121 |
+
mean=IMAGENET_DEFAULT_MEAN,
|
122 |
+
std=IMAGENET_DEFAULT_STD):
|
123 |
+
crop_pct = crop_pct or DEFAULT_CROP_PCT
|
124 |
+
|
125 |
+
if isinstance(img_size, tuple):
|
126 |
+
assert len(img_size) == 2
|
127 |
+
if img_size[-1] == img_size[-2]:
|
128 |
+
# fall-back to older behaviour so Resize scales to shortest edge if target is square
|
129 |
+
scale_size = int(math.floor(img_size[0] / crop_pct))
|
130 |
+
else:
|
131 |
+
scale_size = tuple([int(x / crop_pct) for x in img_size])
|
132 |
+
else:
|
133 |
+
scale_size = int(math.floor(img_size / crop_pct))
|
134 |
+
|
135 |
+
tfl = [
|
136 |
+
transforms.Resize(scale_size, _pil_interp(interpolation)),
|
137 |
+
transforms.CenterCrop(img_size),
|
138 |
+
]
|
139 |
+
if use_prefetcher:
|
140 |
+
# prefetcher and collate will handle tensor conversion and norm
|
141 |
+
tfl += [ToNumpy()]
|
142 |
+
else:
|
143 |
+
tfl += [
|
144 |
+
transforms.ToTensor(),
|
145 |
+
transforms.Normalize(
|
146 |
+
mean=torch.tensor(mean),
|
147 |
+
std=torch.tensor(std))
|
148 |
+
]
|
149 |
+
|
150 |
+
return transforms.Compose(tfl)
|
annotator/normalbae/models/submodules/efficientnet_repo/geffnet/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .gen_efficientnet import *
|
2 |
+
from .mobilenetv3 import *
|
3 |
+
from .model_factory import create_model
|
4 |
+
from .config import is_exportable, is_scriptable, set_exportable, set_scriptable
|
5 |
+
from .activations import *
|
annotator/normalbae/models/submodules/efficientnet_repo/geffnet/activations/__init__.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from geffnet import config
|
2 |
+
from geffnet.activations.activations_me import *
|
3 |
+
from geffnet.activations.activations_jit import *
|
4 |
+
from geffnet.activations.activations import *
|
5 |
+
import torch
|
6 |
+
|
7 |
+
_has_silu = 'silu' in dir(torch.nn.functional)
|
8 |
+
|
9 |
+
_ACT_FN_DEFAULT = dict(
|
10 |
+
silu=F.silu if _has_silu else swish,
|
11 |
+
swish=F.silu if _has_silu else swish,
|
12 |
+
mish=mish,
|
13 |
+
relu=F.relu,
|
14 |
+
relu6=F.relu6,
|
15 |
+
sigmoid=sigmoid,
|
16 |
+
tanh=tanh,
|
17 |
+
hard_sigmoid=hard_sigmoid,
|
18 |
+
hard_swish=hard_swish,
|
19 |
+
)
|
20 |
+
|
21 |
+
_ACT_FN_JIT = dict(
|
22 |
+
silu=F.silu if _has_silu else swish_jit,
|
23 |
+
swish=F.silu if _has_silu else swish_jit,
|
24 |
+
mish=mish_jit,
|
25 |
+
)
|
26 |
+
|
27 |
+
_ACT_FN_ME = dict(
|
28 |
+
silu=F.silu if _has_silu else swish_me,
|
29 |
+
swish=F.silu if _has_silu else swish_me,
|
30 |
+
mish=mish_me,
|
31 |
+
hard_swish=hard_swish_me,
|
32 |
+
hard_sigmoid_jit=hard_sigmoid_me,
|
33 |
+
)
|
34 |
+
|
35 |
+
_ACT_LAYER_DEFAULT = dict(
|
36 |
+
silu=nn.SiLU if _has_silu else Swish,
|
37 |
+
swish=nn.SiLU if _has_silu else Swish,
|
38 |
+
mish=Mish,
|
39 |
+
relu=nn.ReLU,
|
40 |
+
relu6=nn.ReLU6,
|
41 |
+
sigmoid=Sigmoid,
|
42 |
+
tanh=Tanh,
|
43 |
+
hard_sigmoid=HardSigmoid,
|
44 |
+
hard_swish=HardSwish,
|
45 |
+
)
|
46 |
+
|
47 |
+
_ACT_LAYER_JIT = dict(
|
48 |
+
silu=nn.SiLU if _has_silu else SwishJit,
|
49 |
+
swish=nn.SiLU if _has_silu else SwishJit,
|
50 |
+
mish=MishJit,
|
51 |
+
)
|
52 |
+
|
53 |
+
_ACT_LAYER_ME = dict(
|
54 |
+
silu=nn.SiLU if _has_silu else SwishMe,
|
55 |
+
swish=nn.SiLU if _has_silu else SwishMe,
|
56 |
+
mish=MishMe,
|
57 |
+
hard_swish=HardSwishMe,
|
58 |
+
hard_sigmoid=HardSigmoidMe
|
59 |
+
)
|
60 |
+
|
61 |
+
_OVERRIDE_FN = dict()
|
62 |
+
_OVERRIDE_LAYER = dict()
|
63 |
+
|
64 |
+
|
65 |
+
def add_override_act_fn(name, fn):
|
66 |
+
global _OVERRIDE_FN
|
67 |
+
_OVERRIDE_FN[name] = fn
|
68 |
+
|
69 |
+
|
70 |
+
def update_override_act_fn(overrides):
|
71 |
+
assert isinstance(overrides, dict)
|
72 |
+
global _OVERRIDE_FN
|
73 |
+
_OVERRIDE_FN.update(overrides)
|
74 |
+
|
75 |
+
|
76 |
+
def clear_override_act_fn():
|
77 |
+
global _OVERRIDE_FN
|
78 |
+
_OVERRIDE_FN = dict()
|
79 |
+
|
80 |
+
|
81 |
+
def add_override_act_layer(name, fn):
|
82 |
+
_OVERRIDE_LAYER[name] = fn
|
83 |
+
|
84 |
+
|
85 |
+
def update_override_act_layer(overrides):
|
86 |
+
assert isinstance(overrides, dict)
|
87 |
+
global _OVERRIDE_LAYER
|
88 |
+
_OVERRIDE_LAYER.update(overrides)
|
89 |
+
|
90 |
+
|
91 |
+
def clear_override_act_layer():
|
92 |
+
global _OVERRIDE_LAYER
|
93 |
+
_OVERRIDE_LAYER = dict()
|
94 |
+
|
95 |
+
|
96 |
+
def get_act_fn(name='relu'):
|
97 |
+
""" Activation Function Factory
|
98 |
+
Fetching activation fns by name with this function allows export or torch script friendly
|
99 |
+
functions to be returned dynamically based on current config.
|
100 |
+
"""
|
101 |
+
if name in _OVERRIDE_FN:
|
102 |
+
return _OVERRIDE_FN[name]
|
103 |
+
use_me = not (config.is_exportable() or config.is_scriptable() or config.is_no_jit())
|
104 |
+
if use_me and name in _ACT_FN_ME:
|
105 |
+
# If not exporting or scripting the model, first look for a memory optimized version
|
106 |
+
# activation with custom autograd, then fallback to jit scripted, then a Python or Torch builtin
|
107 |
+
return _ACT_FN_ME[name]
|
108 |
+
if config.is_exportable() and name in ('silu', 'swish'):
|
109 |
+
# FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
|
110 |
+
return swish
|
111 |
+
use_jit = not (config.is_exportable() or config.is_no_jit())
|
112 |
+
# NOTE: export tracing should work with jit scripted components, but I keep running into issues
|
113 |
+
if use_jit and name in _ACT_FN_JIT: # jit scripted models should be okay for export/scripting
|
114 |
+
return _ACT_FN_JIT[name]
|
115 |
+
return _ACT_FN_DEFAULT[name]
|
116 |
+
|
117 |
+
|
118 |
+
def get_act_layer(name='relu'):
|
119 |
+
""" Activation Layer Factory
|
120 |
+
Fetching activation layers by name with this function allows export or torch script friendly
|
121 |
+
functions to be returned dynamically based on current config.
|
122 |
+
"""
|
123 |
+
if name in _OVERRIDE_LAYER:
|
124 |
+
return _OVERRIDE_LAYER[name]
|
125 |
+
use_me = not (config.is_exportable() or config.is_scriptable() or config.is_no_jit())
|
126 |
+
if use_me and name in _ACT_LAYER_ME:
|
127 |
+
return _ACT_LAYER_ME[name]
|
128 |
+
if config.is_exportable() and name in ('silu', 'swish'):
|
129 |
+
# FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
|
130 |
+
return Swish
|
131 |
+
use_jit = not (config.is_exportable() or config.is_no_jit())
|
132 |
+
# NOTE: export tracing should work with jit scripted components, but I keep running into issues
|
133 |
+
if use_jit and name in _ACT_FN_JIT: # jit scripted models should be okay for export/scripting
|
134 |
+
return _ACT_LAYER_JIT[name]
|
135 |
+
return _ACT_LAYER_DEFAULT[name]
|
136 |
+
|
137 |
+
|
annotator/normalbae/models/submodules/efficientnet_repo/geffnet/activations/activations.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Activations
|
2 |
+
|
3 |
+
A collection of activations fn and modules with a common interface so that they can
|
4 |
+
easily be swapped. All have an `inplace` arg even if not used.
|
5 |
+
|
6 |
+
Copyright 2020 Ross Wightman
|
7 |
+
"""
|
8 |
+
from torch import nn as nn
|
9 |
+
from torch.nn import functional as F
|
10 |
+
|
11 |
+
|
12 |
+
def swish(x, inplace: bool = False):
|
13 |
+
"""Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3)
|
14 |
+
and also as Swish (https://arxiv.org/abs/1710.05941).
|
15 |
+
|
16 |
+
TODO Rename to SiLU with addition to PyTorch
|
17 |
+
"""
|
18 |
+
return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())
|
19 |
+
|
20 |
+
|
21 |
+
class Swish(nn.Module):
|
22 |
+
def __init__(self, inplace: bool = False):
|
23 |
+
super(Swish, self).__init__()
|
24 |
+
self.inplace = inplace
|
25 |
+
|
26 |
+
def forward(self, x):
|
27 |
+
return swish(x, self.inplace)
|
28 |
+
|
29 |
+
|
30 |
+
def mish(x, inplace: bool = False):
|
31 |
+
"""Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
|
32 |
+
"""
|
33 |
+
return x.mul(F.softplus(x).tanh())
|
34 |
+
|
35 |
+
|
36 |
+
class Mish(nn.Module):
|
37 |
+
def __init__(self, inplace: bool = False):
|
38 |
+
super(Mish, self).__init__()
|
39 |
+
self.inplace = inplace
|
40 |
+
|
41 |
+
def forward(self, x):
|
42 |
+
return mish(x, self.inplace)
|
43 |
+
|
44 |
+
|
45 |
+
def sigmoid(x, inplace: bool = False):
|
46 |
+
return x.sigmoid_() if inplace else x.sigmoid()
|
47 |
+
|
48 |
+
|
49 |
+
# PyTorch has this, but not with a consistent inplace argmument interface
|
50 |
+
class Sigmoid(nn.Module):
|
51 |
+
def __init__(self, inplace: bool = False):
|
52 |
+
super(Sigmoid, self).__init__()
|
53 |
+
self.inplace = inplace
|
54 |
+
|
55 |
+
def forward(self, x):
|
56 |
+
return x.sigmoid_() if self.inplace else x.sigmoid()
|
57 |
+
|
58 |
+
|
59 |
+
def tanh(x, inplace: bool = False):
|
60 |
+
return x.tanh_() if inplace else x.tanh()
|
61 |
+
|
62 |
+
|
63 |
+
# PyTorch has this, but not with a consistent inplace argmument interface
|
64 |
+
class Tanh(nn.Module):
|
65 |
+
def __init__(self, inplace: bool = False):
|
66 |
+
super(Tanh, self).__init__()
|
67 |
+
self.inplace = inplace
|
68 |
+
|
69 |
+
def forward(self, x):
|
70 |
+
return x.tanh_() if self.inplace else x.tanh()
|
71 |
+
|
72 |
+
|
73 |
+
def hard_swish(x, inplace: bool = False):
|
74 |
+
inner = F.relu6(x + 3.).div_(6.)
|
75 |
+
return x.mul_(inner) if inplace else x.mul(inner)
|
76 |
+
|
77 |
+
|
78 |
+
class HardSwish(nn.Module):
|
79 |
+
def __init__(self, inplace: bool = False):
|
80 |
+
super(HardSwish, self).__init__()
|
81 |
+
self.inplace = inplace
|
82 |
+
|
83 |
+
def forward(self, x):
|
84 |
+
return hard_swish(x, self.inplace)
|
85 |
+
|
86 |
+
|
87 |
+
def hard_sigmoid(x, inplace: bool = False):
|
88 |
+
if inplace:
|
89 |
+
return x.add_(3.).clamp_(0., 6.).div_(6.)
|
90 |
+
else:
|
91 |
+
return F.relu6(x + 3.) / 6.
|
92 |
+
|
93 |
+
|
94 |
+
class HardSigmoid(nn.Module):
|
95 |
+
def __init__(self, inplace: bool = False):
|
96 |
+
super(HardSigmoid, self).__init__()
|
97 |
+
self.inplace = inplace
|
98 |
+
|
99 |
+
def forward(self, x):
|
100 |
+
return hard_sigmoid(x, self.inplace)
|
101 |
+
|
102 |
+
|
annotator/normalbae/models/submodules/efficientnet_repo/geffnet/activations/activations_jit.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" Activations (jit)
|
2 |
+
|
3 |
+
A collection of jit-scripted activations fn and modules with a common interface so that they can
|
4 |
+
easily be swapped. All have an `inplace` arg even if not used.
|
5 |
+
|
6 |
+
All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not
|
7 |
+
currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted
|
8 |
+
versions if they contain in-place ops.
|
9 |
+
|
10 |
+
Copyright 2020 Ross Wightman
|
11 |
+
"""
|
12 |
+
|
13 |
+
import torch
|
14 |
+
from torch import nn as nn
|
15 |
+
from torch.nn import functional as F
|
16 |
+
|
17 |
+
__all__ = ['swish_jit', 'SwishJit', 'mish_jit', 'MishJit',
|
18 |
+
'hard_sigmoid_jit', 'HardSigmoidJit', 'hard_swish_jit', 'HardSwishJit']
|
19 |
+
|
20 |
+
|
21 |
+
@torch.jit.script
|
22 |
+
def swish_jit(x, inplace: bool = False):
|
23 |
+
"""Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3)
|
24 |
+
and also as Swish (https://arxiv.org/abs/1710.05941).
|
25 |
+
|
26 |
+
TODO Rename to SiLU with addition to PyTorch
|
27 |
+
"""
|
28 |
+
return x.mul(x.sigmoid())
|
29 |
+
|
30 |
+
|
31 |
+
@torch.jit.script
|
32 |
+
def mish_jit(x, _inplace: bool = False):
|
33 |
+
"""Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
|
34 |
+
"""
|
35 |
+
return x.mul(F.softplus(x).tanh())
|
36 |
+
|
37 |
+
|
38 |
+
class SwishJit(nn.Module):
|
39 |
+
def __init__(self, inplace: bool = False):
|
40 |
+
super(SwishJit, self).__init__()
|
41 |
+
|
42 |
+
def forward(self, x):
|
43 |
+
return swish_jit(x)
|
44 |
+
|
45 |
+
|
46 |
+
class MishJit(nn.Module):
|
47 |
+
def __init__(self, inplace: bool = False):
|
48 |
+
super(MishJit, self).__init__()
|
49 |
+
|
50 |
+
def forward(self, x):
|
51 |
+
return mish_jit(x)
|
52 |
+
|
53 |
+
|
54 |
+
@torch.jit.script
|
55 |
+
def hard_sigmoid_jit(x, inplace: bool = False):
|
56 |
+
# return F.relu6(x + 3.) / 6.
|
57 |
+
return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster?
|
58 |
+
|
59 |
+
|
60 |
+
class HardSigmoidJit(nn.Module):
|
61 |
+
def __init__(self, inplace: bool = False):
|
62 |
+
super(HardSigmoidJit, self).__init__()
|
63 |
+
|
64 |
+
def forward(self, x):
|
65 |
+
return hard_sigmoid_jit(x)
|
66 |
+
|
67 |
+
|
68 |
+
@torch.jit.script
|
69 |
+
def hard_swish_jit(x, inplace: bool = False):
|
70 |
+
# return x * (F.relu6(x + 3.) / 6)
|
71 |
+
return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster?
|
72 |
+
|
73 |
+
|
74 |
+
class HardSwishJit(nn.Module):
|
75 |
+
def __init__(self, inplace: bool = False):
|
76 |
+
super(HardSwishJit, self).__init__()
|
77 |
+
|
78 |
+
def forward(self, x):
|
79 |
+
return hard_swish_jit(x)
|