added CHM classification
- CHMCorr.py +546 -0
- ExtractEmbedding.py +2 -2
- FeatureExtractors.py +391 -0
- Utils.py +323 -0
- app.py +46 -11
- common/evaluation.py +32 -0
- common/logger.py +117 -0
- examples/Red_Winged_Blackbird_0012_6015.jpg +0 -0
- examples/Red_Winged_Blackbird_0025_5342.jpg +0 -0
- examples/Yellow_Headed_Blackbird_0020_8549.jpg +0 -0
- examples/Yellow_Headed_Blackbird_0026_8545.jpg +0 -0
- examples/sample1.jpeg +0 -0
- examples/sample2.jpeg +0 -0
- model/base/backbone.py +136 -0
- model/base/chm.py +190 -0
- model/base/chm_kernel.py +66 -0
- model/base/correlation.py +68 -0
- model/base/geometry.py +133 -0
- model/chmlearner.py +52 -0
- model/chmnet.py +42 -0
- visualization.py +274 -0
CHMCorr.py
ADDED
@@ -0,0 +1,546 @@
# CHM-Corr Classifier
import argparse
import json
import pickle
import random
from itertools import product
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from tqdm import tqdm
from common.evaluation import Evaluator
from model import chmnet
from model.base.geometry import Geometry

from Utils import (
    CosineCustomDataset,
    PairedLayer4Extractor,
    compute_spatial_similarity,
    generate_mask,
    normalize_array,
    get_transforms,
    arg_topK,
)

# Setting the random seed
random.seed(42)

# Helper Function
to_np = lambda x: x.data.to("cpu").numpy()

# CHMNet Config
chm_args = dict(
    {
        "alpha": [0.05, 0.1],
        "img_size": 240,
        "ktype": "psi",
        "load": "pas_psi.pt",
    }
)


class CHMGridTransfer:
    def __init__(
        self,
        query_image,
        support_set,
        support_set_labels,
        train_folder,
        top_N,
        top_K,
        binarization_threshold,
        chm_source_transform,
        chm_target_transform,
        cosine_source_transform,
        cosine_target_transform,
        batch_size=64,
    ):
        self.N = top_N
        self.K = top_K
        self.BS = batch_size

        self.chm_source_transform = chm_source_transform
        self.chm_target_transform = chm_target_transform
        self.cosine_source_transform = cosine_source_transform
        self.cosine_target_transform = cosine_target_transform

        self.source_embeddings = None
        self.target_embeddings = None
        self.correspondence_map = None
        self.similarity_maps = None
        self.reverse_similarity_maps = None
        self.transferred_points = None

        self.binarization_threshold = binarization_threshold

        # UPDATE THIS
        self.q = query_image
        self.support_set = support_set
        self.labels_ss = support_set_labels

    def build(self):
        # C.M.H
        test_ds = CosineCustomDataset(
            query_image=self.q,
            supporting_set=self.support_set,
            source_transform=self.chm_source_transform,
            target_transform=self.chm_target_transform,
        )
        test_dl = DataLoader(test_ds, batch_size=self.BS, shuffle=False)
        self.find_correspondences(test_dl)

        # LAYER 4s
        test_ds = CosineCustomDataset(
            query_image=self.q,
            supporting_set=self.support_set,
            source_transform=self.cosine_source_transform,
            target_transform=self.cosine_target_transform,
        )
        test_dl = DataLoader(test_ds, batch_size=self.BS, shuffle=False)
        self.compute_embeddings(test_dl)
        self.compute_similarity_map()

    def find_correspondences(self, test_dl):
        model = chmnet.CHMNet(chm_args["ktype"])
        model.load_state_dict(
            torch.load(chm_args["load"], map_location=torch.device("cpu"))
        )
        Evaluator.initialize(chm_args["alpha"])
        Geometry.initialize(img_size=chm_args["img_size"])

        grid_results = []
        transferred_points = []

        # FIXED GRID HARD CODED
        fixed_src_grid_points = list(
            product(
                np.linspace(1 + 17, 240 - 17 - 1, 7),
                np.linspace(1 + 17, 240 - 17 - 1, 7),
            )
        )
        fixed_src_grid_points = np.asarray(fixed_src_grid_points, dtype=np.float64).T

        with torch.no_grad():
            model.eval()
            for idx, batch in enumerate(tqdm(test_dl)):
                keypoints = (
                    torch.tensor(fixed_src_grid_points)
                    .unsqueeze(0)
                    .repeat(batch["src_img"].shape[0], 1, 1)
                )
                n_pts = torch.tensor(
                    np.asarray(batch["src_img"].shape[0] * [49]), dtype=torch.long
                )

                corr_matrix = model(batch["src_img"], batch["trg_img"])
                prd_kps = Geometry.transfer_kps(
                    corr_matrix, keypoints, n_pts, normalized=False
                )
                transferred_points.append(prd_kps.cpu().numpy())
                for tgt_points in prd_kps:
                    tgt_grid = []
                    for x, y in zip(tgt_points[0], tgt_points[1]):
                        tgt_grid.append(
                            [int(((x + 1) / 2.0) * 7), int(((y + 1) / 2.0) * 7)]
                        )
                    grid_results.append(tgt_grid)

        self.correspondence_map = grid_results
        self.transferred_points = np.vstack(transferred_points)

    def compute_embeddings(self, test_dl):
        paired_extractor = PairedLayer4Extractor()

        source_embeddings = []
        target_embeddings = []

        with torch.no_grad():
            for idx, batch in enumerate(test_dl):
                s_e, t_e = paired_extractor((batch["src_img"], batch["trg_img"]))

                source_embeddings.append(s_e)
                target_embeddings.append(t_e)

        # EMBEDDINGS
        self.source_embeddings = torch.cat(source_embeddings, axis=0)
        self.target_embeddings = torch.cat(target_embeddings, axis=0)

    def compute_similarity_map(self):
        CosSim = nn.CosineSimilarity(dim=0, eps=1e-6)

        similarity_maps = []
        rsimilarity_maps = []

        grid = []
        for i in range(7):
            for j in range(7):
                grid.append([i, j])

        # Compute for all image pairs
        for i in range(len(self.correspondence_map)):
            cosine_map = np.zeros((7, 7))
            reverse_cosine_map = np.zeros((7, 7))

            # calculate cosine based on the chm corr. map
            for S, T in zip(grid, self.correspondence_map[i]):
                v1 = self.source_embeddings[i][:, S[0], S[1]]
                v2 = self.target_embeddings[i][:, T[0], T[1]]
                covalue = CosSim(v1, v2)
                cosine_map[S[0], S[1]] = covalue
                reverse_cosine_map[T[0], T[1]] = covalue

            similarity_maps.append(cosine_map)
            rsimilarity_maps.append(reverse_cosine_map)

        self.similarity_maps = similarity_maps
        self.reverse_similarity_maps = rsimilarity_maps

    def compute_score_using_cc(self):
        # CC MAPS
        SIMS_source, SIMS_target = [], []
        for i in range(len(self.source_embeddings)):
            simA, simB = compute_spatial_similarity(
                to_np(self.source_embeddings[i]), to_np(self.target_embeddings[i])
            )

            SIMS_source.append(simA)
            SIMS_target.append(simB)

        SIMS_source = np.stack(SIMS_source, axis=0)
        # SIMS_target = np.stack(SIMS_target, axis=0)

        top_cos_values = []

        for i in range(len(self.similarity_maps)):
            cosine_value = np.multiply(
                self.similarity_maps[i],
                generate_mask(
                    normalize_array(SIMS_source[i]), t=self.binarization_threshold
                ),
            )
            top_5_indicies = np.argsort(cosine_value.T.reshape(-1))[::-1][:5]
            mean_of_top_5 = np.mean(
                [cosine_value.T.reshape(-1)[x] for x in top_5_indicies]
            )
            top_cos_values.append(np.mean(mean_of_top_5))

        return top_cos_values

    def compute_score_using_custom_points(self, selected_keypoint_masks):
        top_cos_values = []

        for i in range(len(self.similarity_maps)):
            cosine_value = np.multiply(self.similarity_maps[i], selected_keypoint_masks)
            top_indicies = np.argsort(cosine_value.T.reshape(-1))[::-1]
            mean_of_tops = np.mean(
                [cosine_value.T.reshape(-1)[x] for x in top_indicies]
            )
            top_cos_values.append(np.mean(mean_of_tops))

        return top_cos_values

    def export(self):
        storage = {
            "N": self.N,
            "K": self.K,
            "source_embeddings": self.source_embeddings,
            "target_embeddings": self.target_embeddings,
            "correspondence_map": self.correspondence_map,
            "similarity_maps": self.similarity_maps,
            "T": self.binarization_threshold,
            "query": self.q,
            "support_set": self.support_set,
            "labels_for_support_set": self.labels_ss,
            "rsimilarity_maps": self.reverse_similarity_maps,
            "transferred_points": self.transferred_points,
        }

        return ModifiableCHMResults(storage)


class ModifiableCHMResults:
    def __init__(self, storage):
        self.N = storage["N"]
        self.K = storage["K"]
        self.source_embeddings = storage["source_embeddings"]
        self.target_embeddings = storage["target_embeddings"]
        self.correspondence_map = storage["correspondence_map"]
        self.similarity_maps = storage["similarity_maps"]
        self.T = storage["T"]
        self.q = storage["query"]
        self.support_set = storage["support_set"]
        self.labels_ss = storage["labels_for_support_set"]
        self.rsimilarity_maps = storage["rsimilarity_maps"]
        self.transferred_points = storage["transferred_points"]
        self.similarity_maps_masked = None
        self.SIMS_source = None
        self.SIMS_target = None
        self.masked_sim_values = []
        self.top_cos_values = []

    def compute_score_using_cc(self):
        # CC MAPS
        SIMS_source, SIMS_target = [], []
        for i in range(len(self.source_embeddings)):
            simA, simB = compute_spatial_similarity(
                to_np(self.source_embeddings[i]), to_np(self.target_embeddings[i])
            )

            SIMS_source.append(simA)
            SIMS_target.append(simB)

        SIMS_source = np.stack(SIMS_source, axis=0)
        SIMS_target = np.stack(SIMS_target, axis=0)

        self.SIMS_source = SIMS_source
        self.SIMS_target = SIMS_target

        top_cos_values = []

        for i in range(len(self.similarity_maps)):
            masked_sim_values = np.multiply(
                self.similarity_maps[i],
                generate_mask(normalize_array(SIMS_source[i]), t=self.T),
            )
            self.masked_sim_values.append(masked_sim_values)
            top_5_indicies = np.argsort(masked_sim_values.T.reshape(-1))[::-1][:5]
            mean_of_top_5 = np.mean(
                [masked_sim_values.T.reshape(-1)[x] for x in top_5_indicies]
            )
            top_cos_values.append(np.mean(mean_of_top_5))

        self.top_cos_values = top_cos_values

        return top_cos_values

    def compute_score_using_custom_points(self, selected_keypoint_masks):
        top_cos_values = []
        similarity_maps_masked = []

        for i in range(len(self.similarity_maps)):
            cosine_value = np.multiply(self.similarity_maps[i], selected_keypoint_masks)
            similarity_maps_masked.append(cosine_value)
            top_indicies = np.argsort(cosine_value.T.reshape(-1))[::-1]
            mean_of_tops = np.mean(
                [cosine_value.T.reshape(-1)[x] for x in top_indicies]
            )
            top_cos_values.append(np.mean(mean_of_tops))

        self.similarity_maps_masked = similarity_maps_masked
        return top_cos_values

    def predict_using_cc(self):
        top_cos_values = self.compute_score_using_cc()
        # Predict
        prediction = np.argmax(
            np.bincount(
                [self.labels_ss[x] for x in np.argsort(top_cos_values)[::-1][: self.K]]
            )
        )
        prediction_weight = np.max(
            np.bincount(
                [self.labels_ss[x] for x in np.argsort(top_cos_values)[::-1][: self.K]]
            )
        )

        reranked_nns_idx = [x for x in np.argsort(top_cos_values)[::-1]]
        reranked_nns_files = [self.support_set[x] for x in reranked_nns_idx]

        topK_idx = [
            x
            for x in np.argsort(top_cos_values)[::-1]
            if self.labels_ss[x] == prediction
        ]
        topK_files = [self.support_set[x] for x in topK_idx]
        topK_cmaps = [self.correspondence_map[x] for x in topK_idx]
        topK_similarity_maps = [self.similarity_maps[x] for x in topK_idx]
        topK_rsimilarity_maps = [self.rsimilarity_maps[x] for x in topK_idx]
        topK_transfered_points = [self.transferred_points[x] for x in topK_idx]
        predicted_folder_name = topK_files[0].split("/")[-2]

        return (
            topK_idx,
            prediction,
            predicted_folder_name,
            prediction_weight,
            topK_files[: self.K],
            reranked_nns_files[: self.K],
            topK_cmaps[: self.K],
            topK_similarity_maps[: self.K],
            topK_rsimilarity_maps[: self.K],
            topK_transfered_points[: self.K],
        )

    def predict_custom_pairs(self, selected_keypoint_masks):
        top_cos_values = self.compute_score_using_custom_points(selected_keypoint_masks)

        # Predict
        prediction = np.argmax(
            np.bincount(
                [self.labels_ss[x] for x in np.argsort(top_cos_values)[::-1][: self.K]]
            )
        )
        prediction_weight = np.max(
            np.bincount(
                [self.labels_ss[x] for x in np.argsort(top_cos_values)[::-1][: self.K]]
            )
        )

        reranked_nns_idx = [x for x in np.argsort(top_cos_values)[::-1]]
        reranked_nns_files = [self.support_set[x] for x in reranked_nns_idx]

        topK_idx = [
            x
            for x in np.argsort(top_cos_values)[::-1]
            if self.labels_ss[x] == prediction
        ]
        topK_files = [self.support_set[x] for x in topK_idx]
        topK_cmaps = [self.correspondence_map[x] for x in topK_idx]
        topK_similarity_maps = [self.similarity_maps[x] for x in topK_idx]
        topK_rsimilarity_maps = [self.rsimilarity_maps[x] for x in topK_idx]
        topK_transferred_points = [self.transferred_points[x] for x in topK_idx]
        # topK_scores = [top_cos_values[x] for x in topK_idx]
        topK_masked_sims = [self.similarity_maps_masked[x] for x in topK_idx]
        predicted_folder_name = topK_files[0].split("/")[-2]

        non_zero_mask = np.count_nonzero(selected_keypoint_masks)

        return (
            topK_idx,
            prediction,
            predicted_folder_name,
            prediction_weight,
            topK_files[: self.K],
            reranked_nns_files[: self.K],
            topK_cmaps[: self.K],
            topK_similarity_maps[: self.K],
            topK_rsimilarity_maps[: self.K],
            topK_transferred_points[: self.K],
            topK_masked_sims[: self.K],
            non_zero_mask,
        )


def export_visualizations_results(
    reranker_output,
    knn_predicted_label,
    knn_confidence,
    topK_knns,
    K=20,
    N=50,
    T=0.55,
):
    """
    Export all details for visualization and analysis
    """

    non_zero_mask = 5  # default value
    (
        topK_idx,
        p,
        pfn,
        pr,
        rfiles,
        reranked_nns,
        cmaps,
        sims,
        rsims,
        trns_kpts,
    ) = reranker_output.predict_using_cc()

    MASKED_COSINE_VALUES = [
        np.multiply(
            sims[X],
            generate_mask(
                normalize_array(reranker_output.SIMS_source[topK_idx[X]]), t=T
            ),
        )
        for X in range(len(sims))
    ]

    list_of_source_points = []
    list_of_target_points = []

    for CK in range(len(sims)):
        target_keypoints = []
        topk_index = arg_topK(MASKED_COSINE_VALUES[CK], topK=non_zero_mask)

        for i in range(non_zero_mask):  # Number of Connections
            # Psource = point_list[topk_index[i]]
            x, y = trns_kpts[CK].T[topk_index[i]]
            Ptarget = int(((x + 1) / 2.0) * 240), int(((y + 1) / 2.0) * 240)
            target_keypoints.append(Ptarget)

        # Uniform Grid of points
        a = np.linspace(1 + 17, 240 - 17 - 1, 7)
        b = np.linspace(1 + 17, 240 - 17 - 1, 7)
        point_list = list(product(a, b))

        list_of_source_points.append(np.asarray([point_list[x] for x in topk_index]))
        list_of_target_points.append(np.asarray(target_keypoints))

    # EXPORT OUTPUT
    detailed_output = {
        "q": reranker_output.q,
        "K": K,
        "N": N,
        "knn-prediction": knn_predicted_label,
        "knn-prediction-confidence": knn_confidence,
        "knn-nearest-neighbors": topK_knns,
        "chm-prediction": pfn,
        "chm-prediction-confidence": pr,
        "chm-nearest-neighbors": rfiles,
        "correspondance_map": cmaps,
        "masked_cos_values": MASKED_COSINE_VALUES,
        "src-keypoints": list_of_source_points,
        "tgt-keypoints": list_of_target_points,
        "non_zero_mask": non_zero_mask,
        "transferred_kpoints": trns_kpts,
    }

    return detailed_output


def chm_classify_and_visualize(
    query_image, kNN_results, support, TRAIN_SET, N=50, K=20, T=0.55, BS=64
):
    global chm_args
    chm_src_t, chm_tgt_t, cos_src_t, cos_tgt_t = get_transforms("single", chm_args)
    knn_predicted_label, knn_confidence, topK_knns = kNN_results

    reranker = CHMGridTransfer(
        query_image=query_image,
        support_set=support[0],
        support_set_labels=support[1],
        train_folder=TRAIN_SET,
        top_N=N,
        top_K=K,
        binarization_threshold=T,
        chm_source_transform=chm_src_t,
        chm_target_transform=chm_tgt_t,
        cosine_source_transform=cos_src_t,
        cosine_target_transform=cos_tgt_t,
        batch_size=BS,
    )

    # Building the reranker
    reranker.build()
    # Make a ModifiableCHMResults
    exported_reranker = reranker.export()
    # Export the details for visualizations

    output = export_visualizations_results(
        exported_reranker,
        knn_predicted_label,
        knn_confidence,
        topK_knns,
        K,
        N,
        T,
    )

    return output
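For orientation: chm_classify_and_visualize is the entry point app.py calls. It builds a CHMGridTransfer over the kNN support set, computes CHM correspondences and layer-4 cosine maps, reranks the support images, and packs the result into the visualization dictionary. A minimal calling sketch (not part of the commit; paths, labels, and vote counts below are placeholders):

from CHMCorr import chm_classify_and_visualize

# Placeholder inputs shaped like what app.py passes in.
query = "examples/sample1.jpeg"                          # hypothetical query path
support_files = ["data/train/some_class/img_0001.jpg"]   # hypothetical support paths
support_labels = [0]
kNN_results = (0, 17, support_files[:5])                 # (top-1 label, votes, gallery)

output = chm_classify_and_visualize(
    query, kNN_results, [support_files, support_labels], "data/train/"
)
print(output["chm-prediction"], output["chm-prediction-confidence"])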
ExtractEmbedding.py
CHANGED
@@ -36,7 +36,7 @@ class Wrapper(torch.nn.Module):
         return "Wrappper"
 
 
-def QueryToEmbedding(query_pil):
+def QueryToEmbedding(query_path):
     dataset_transform = transforms.Compose(
         [
             transforms.Resize(256),
@@ -50,7 +50,7 @@ def QueryToEmbedding(query_pil):
     model.eval()
     myw = Wrapper(model)
 
-
+    query_pil = Image.open(query_path)
    query_pt = dataset_transform(query_pil)
 
    with torch.no_grad():
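QueryToEmbedding now takes a file path and opens the image itself (the hunk assumes PIL's Image is already imported in this module), so app.py can hand it a Gradio filepath input directly. A minimal sketch (the path is a placeholder):

from ExtractEmbedding import QueryToEmbedding

embedding = QueryToEmbedding("./examples/bird.jpg")  # placeholder path
print(embedding.shape)  # exact shape depends on the wrapped backbone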
FeatureExtractors.py
ADDED
@@ -0,0 +1,391 @@
# Original Author: Jonathan Donnellya ([email protected])
# Modified by Mohammad Reza Taesiri ([email protected])

import os
import torch
import torch.nn as nn
from collections import OrderedDict

model_dir = os.path.dirname(os.path.realpath(__file__))


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
    )


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    # class attribute
    expansion = 1
    num_layers = 2

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # only conv with possibly not 1 stride
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        # if stride is not 1 then self.downsample cannot be None
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        # the residual connection
        out += identity
        out = self.relu(out)

        return out

    def block_conv_info(self):
        block_kernel_sizes = [3, 3]
        block_strides = [self.stride, 1]
        block_paddings = [1, 1]

        return block_kernel_sizes, block_strides, block_paddings


class Bottleneck(nn.Module):
    # class attribute
    expansion = 4
    num_layers = 3

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        # only conv with possibly not 1 stride
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)

        # if stride is not 1 then self.downsample cannot be None
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

    def block_conv_info(self):
        block_kernel_sizes = [1, 3, 1]
        block_strides = [1, self.stride, 1]
        block_paddings = [0, 1, 0]

        return block_kernel_sizes, block_strides, block_paddings


class ResNet_features(nn.Module):
    """
    the convolutional layers of ResNet
    the average pooling and final fully convolutional layer is removed
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False):
        super(ResNet_features, self).__init__()

        self.inplanes = 64

        # the first convolutional layer before the structured sequence of blocks
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # comes from the first conv and the following max pool
        self.kernel_sizes = [7, 3]
        self.strides = [2, 2]
        self.paddings = [3, 1]

        # the following layers, each layer is a sequence of blocks
        self.block = block
        self.layers = layers
        self.layer1 = self._make_layer(
            block=block, planes=64, num_blocks=self.layers[0]
        )
        self.layer2 = self._make_layer(
            block=block, planes=128, num_blocks=self.layers[1], stride=2
        )
        self.layer3 = self._make_layer(
            block=block, planes=256, num_blocks=self.layers[2], stride=2
        )
        self.layer4 = self._make_layer(
            block=block, planes=512, num_blocks=self.layers[3], stride=2
        )

        # initialize the parameters
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, num_blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        # only the first block has downsample that is possibly not None
        layers.append(block(self.inplanes, planes, stride, downsample))

        self.inplanes = planes * block.expansion
        for _ in range(1, num_blocks):
            layers.append(block(self.inplanes, planes))

        # keep track of every block's conv size, stride size, and padding size
        for each_block in layers:
            (
                block_kernel_sizes,
                block_strides,
                block_paddings,
            ) = each_block.block_conv_info()
            self.kernel_sizes.extend(block_kernel_sizes)
            self.strides.extend(block_strides)
            self.paddings.extend(block_paddings)

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        return x

    def conv_info(self):
        return self.kernel_sizes, self.strides, self.paddings

    def num_layers(self):
        """
        the number of conv layers in the network, not counting the number
        of bypass layers
        """

        return (
            self.block.num_layers * self.layers[0]
            + self.block.num_layers * self.layers[1]
            + self.block.num_layers * self.layers[2]
            + self.block.num_layers * self.layers[3]
            + 1
        )

    def __repr__(self):
        template = "resnet{}_features"
        return template.format(self.num_layers() + 1)


def resnet50_features(pretrained=True, inat=True, **kwargs):
    """Constructs a ResNet-50 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet or iNaturalist
        inat (bool): If True, returns a model pre-trained on iNaturalist; else, ImageNet
    """
    model = ResNet_features(Bottleneck, [3, 4, 6, 4], **kwargs)
    if pretrained:
        if inat:
            # print('Loading iNat model')
            model_dict = torch.load(
                model_dir
                + "/../../weights/"
                + "BBN.iNaturalist2017.res50.90epoch.best_model.pth.pt"
            )
        else:
            raise

        if inat:
            model_dict.pop("module.classifier.weight")
            model_dict.pop("module.classifier.bias")
            for key in list(model_dict.keys()):
                model_dict[
                    key.replace("module.backbone.", "")
                    .replace("cb_block", "layer4.2")
                    .replace("rb_block", "layer4.3")
                ] = model_dict.pop(key)

        else:
            raise

        model.load_state_dict(model_dict, strict=False)

    return model


class ResNet_classifier(nn.Module):
    """
    A classifier for Deformable ProtoPNet
    """

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False):
        super(ResNet_classifier, self).__init__()

        self.inplanes = 64

        # the first convolutional layer before the structured sequence of blocks
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # comes from the first conv and the following max pool
        self.kernel_sizes = [7, 3]
        self.strides = [2, 2]
        self.paddings = [3, 1]

        # the following layers, each layer is a sequence of blocks
        self.block = block
        self.layers = layers
        self.layer1 = self._make_layer(
            block=block, planes=64, num_blocks=self.layers[0]
        )
        self.layer2 = self._make_layer(
            block=block, planes=128, num_blocks=self.layers[1], stride=2
        )
        self.layer3 = self._make_layer(
            block=block, planes=256, num_blocks=self.layers[2], stride=2
        )
        self.layer4 = self._make_layer(
            block=block, planes=512, num_blocks=self.layers[3], stride=2
        )

        self.classifier = nn.Linear(2048 * 7 * 7, 200)

        # initialize the parameters
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, num_blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        # only the first block has downsample that is possibly not None
        layers.append(block(self.inplanes, planes, stride, downsample))

        self.inplanes = planes * block.expansion
        for _ in range(1, num_blocks):
            layers.append(block(self.inplanes, planes))

        # keep track of every block's conv size, stride size, and padding size
        for each_block in layers:
            (
                block_kernel_sizes,
                block_strides,
                block_paddings,
            ) = each_block.block_conv_info()
            self.kernel_sizes.extend(block_kernel_sizes)
            self.strides.extend(block_strides)
            self.paddings.extend(block_paddings)

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.classifier(torch.flatten(x, start_dim=1))
        return x

    def conv_info(self):
        return self.kernel_sizes, self.strides, self.paddings

    def num_layers(self):
        """
        the number of conv layers in the network, not counting the number
        of bypass layers
        """

        return (
            self.block.num_layers * self.layers[0]
            + self.block.num_layers * self.layers[1]
            + self.block.num_layers * self.layers[2]
            + self.block.num_layers * self.layers[3]
            + 1
        )

    def __repr__(self):
        template = "resnet{}_features"
        return template.format(self.num_layers() + 1)
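Note the layer spec [3, 4, 6, 4]: layer4 carries one block more than a stock ResNet-50's [3, 4, 6, 3], presumably to hold the BBN checkpoint's cb_block/rb_block, which the loader remaps to layer4.2/layer4.3. A quick smoke-test sketch (assumes the iNaturalist weight file referenced above is present on disk):

import torch
from FeatureExtractors import resnet50_features

backbone = resnet50_features(pretrained=True, inat=True)  # loads the ../../weights/BBN... checkpoint
backbone.eval()
with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 224, 224))
print(feats.shape)  # expected layer-4 map: torch.Size([1, 2048, 7, 7])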
Utils.py
ADDED
@@ -0,0 +1,323 @@
import numpy as np
import torch
import torchvision.models as models
from numpy import matlib as mb
from PIL import Image
from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from FeatureExtractors import resnet50_features

to_np = lambda x: x.data.to("cpu").numpy()


def compute_spatial_similarity(conv1, conv2):
    """
    Takes in the last convolutional layer from two images, computes the pooled output
    feature, and then generates the spatial similarity map for both images.
    """
    conv1 = conv1.reshape(-1, 7 * 7).T
    conv2 = conv2.reshape(-1, 7 * 7).T

    pool1 = np.mean(conv1, axis=0)
    pool2 = np.mean(conv2, axis=0)
    out_sz = (int(np.sqrt(conv1.shape[0])), int(np.sqrt(conv1.shape[0])))
    conv1_normed = conv1 / np.linalg.norm(pool1) / conv1.shape[0]
    conv2_normed = conv2 / np.linalg.norm(pool2) / conv2.shape[0]
    im_similarity = np.zeros((conv1_normed.shape[0], conv1_normed.shape[0]))

    for zz in range(conv1_normed.shape[0]):
        repPx = mb.repmat(conv1_normed[zz, :], conv1_normed.shape[0], 1)
        im_similarity[zz, :] = np.multiply(repPx, conv2_normed).sum(axis=1)
    similarity1 = np.reshape(np.sum(im_similarity, axis=1), out_sz)
    similarity2 = np.reshape(np.sum(im_similarity, axis=0), out_sz)
    return similarity1, similarity2


def normalize_array(x):
    x = np.asarray(x).copy()
    x -= np.min(x)
    x /= np.max(x)
    return x


def apply_threshold(x, t):
    x = np.asarray(x).copy()
    x[x < t] = 0
    return x


def generate_mask(x, t):
    v = np.zeros_like(x)
    v[x >= t] = 1
    return v


def get_transforms(args_transform, chm_args):
    # TRANSFORMS
    cosine_transform_target = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )

    chm_transform_target = transforms.Compose(
        [
            transforms.Resize(chm_args["img_size"]),
            transforms.CenterCrop((chm_args["img_size"], chm_args["img_size"])),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    if args_transform == "multi":
        cosine_transform_source = transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )

        chm_transform_source = transforms.Compose(
            [
                transforms.Resize((chm_args["img_size"], chm_args["img_size"])),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    elif args_transform == "single":
        cosine_transform_source = transforms.Compose(
            [
                transforms.Resize(chm_args["img_size"]),
                transforms.CenterCrop((chm_args["img_size"], chm_args["img_size"])),
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )

        chm_transform_source = transforms.Compose(
            [
                transforms.Resize(chm_args["img_size"]),
                transforms.CenterCrop((chm_args["img_size"], chm_args["img_size"])),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    return (
        chm_transform_source,
        chm_transform_target,
        cosine_transform_source,
        cosine_transform_target,
    )


def clamp(x, min_value, max_value):
    return max(min_value, min(x, max_value))


def keep_top5(input_array, K=5):
    top_5 = np.sort(input_array.reshape(-1))[::-1][K - 1]
    masked = np.zeros_like(input_array)
    masked[input_array >= top_5] = 1
    return masked


def arg_topK(input_array, topK=5):
    return np.argsort(input_array.T.reshape(-1))[::-1][:topK]


class KNNSupportSet:
    def __init__(self, train_folder, val_folder, knn_scores, custom_val_labels=None):
        self.train_data = ImageFolder(root=train_folder)
        self.val_data = ImageFolder(root=val_folder)
        self.knn_scores = knn_scores

        if custom_val_labels is None:
            self.val_labels = np.asarray([x[1] for x in self.val_data.imgs])
        else:
            self.val_labels = custom_val_labels

        self.train_labels = np.asarray([x[1] for x in self.train_data.imgs])

    def get_knn_predictions(self, k=20):
        knn_predictions = [
            np.argmax(np.bincount(self.train_labels[self.knn_scores[I][::-1][:k]]))
            for I in range(len(self.knn_scores))
        ]
        knn_accuracy = (
            100
            * np.sum((np.asarray(knn_predictions) == self.val_labels))
            / len(self.val_labels)
        )
        return knn_predictions, knn_accuracy

    def get_support_set(self, selected_index, top_N=20):
        support_set = self.knn_scores[selected_index][-top_N:][::-1]
        return [self.train_data.imgs[x][0] for x in support_set]

    def get_support_set_labels(self, selected_index, top_N=20):
        support_set = self.knn_scores[selected_index][-top_N:][::-1]
        return [self.train_data.imgs[x][1] for x in support_set]

    def get_image_and_label_by_id(self, q_id):
        q = self.val_data.imgs[q_id][0]
        ql = self.val_data.imgs[q_id][1]
        return (q, ql)

    def get_folder_name(self, q_id):
        q = self.val_data.imgs[q_id][0]
        return q.split("/")[-2]

    def get_top5_knn(self, query_id, k=20):
        knn_pred, knn_acc = self.get_knn_predictions(k=k)
        top_5s_index = np.where(
            np.equal(
                self.train_labels[self.knn_scores[query_id][::-1]], knn_pred[query_id]
            )
        )[0][:5]
        top_5s = self.knn_scores[query_id][::-1][top_5s_index]
        top_5s_files = [self.train_data.imgs[x][0] for x in top_5s]
        return top_5s_files

    def get_topK_knn(self, query_id, k=20):
        knn_pred, knn_acc = self.get_knn_predictions(k=k)
        top_ks_index = np.where(
            np.equal(
                self.train_labels[self.knn_scores[query_id][::-1]], knn_pred[query_id]
            )
        )[0][:k]
        top_ks = self.knn_scores[query_id][::-1][top_ks_index]
        top_ks_files = [self.train_data.imgs[x][0] for x in top_ks]
        return top_ks_files

    def get_foldername_for_label(self, label):
        for i in range(len(self.train_data)):
            if self.train_data.imgs[i][1] == label:
                return self.train_data.imgs[i][0].split("/")[-2]

    def get_knn_confidence(self, query_id, k=20):
        return np.max(
            np.bincount(self.train_labels[self.knn_scores[query_id][::-1][:k]])
        )


class CosineCustomDataset(Dataset):
    r"""Parent class of PFPascal, PFWillow, and SPair"""

    def __init__(self, query_image, supporting_set, source_transform, target_transform):
        r"""CosineCustomDataset constructor"""
        super(CosineCustomDataset, self).__init__()

        self.supporting_set = supporting_set
        self.query_image = [query_image] * len(supporting_set)

        self.source_transform = source_transform
        self.target_transform = target_transform

    def __len__(self):
        r"""Returns the number of pairs"""
        return len(self.supporting_set)

    def __getitem__(self, idx):
        r"""Constructs and return a batch"""

        # Image name
        batch = dict()
        batch["src_imname"] = self.query_image[idx]
        batch["trg_imname"] = self.supporting_set[idx]

        # Image as numpy (original width, original height)
        src_pil = self.get_image(self.query_image, idx)
        trg_pil = self.get_image(self.supporting_set, idx)

        batch["src_imsize"] = src_pil.size
        batch["trg_imsize"] = trg_pil.size

        # Image as tensor
        batch["src_img"] = self.source_transform(src_pil)
        batch["trg_img"] = self.target_transform(trg_pil)

        # Total number of pairs in training split
        batch["datalen"] = len(self.query_image)
        return batch

    def get_image(self, image_pathes, idx):
        r"""Reads PIL image from path"""
        path = image_pathes[idx]
        return Image.open(path).convert("RGB")


class PairedLayer4Extractor(torch.nn.Module):
    """
    Extracting layer-4 embedding for source and target images using ResNet-50 features
    """

    def __init__(self):
        super(PairedLayer4Extractor, self).__init__()

        self.modelA = models.resnet50(pretrained=True)
        self.modelA.eval()

        self.modelB = models.resnet50(pretrained=True)
        self.modelB.eval()

        self.a_embeddings = None
        self.b_embeddings = None

        def a_hook(module, input, output):
            self.a_embeddings = output

        def b_hook(module, input, output):
            self.b_embeddings = output

        self.modelA._modules.get("layer4").register_forward_hook(a_hook)
        self.modelB._modules.get("layer4").register_forward_hook(b_hook)

    def forward(self, inputs):
        inputA, inputB = inputs
        self.modelA(inputA)
        self.modelB(inputB)

        return self.a_embeddings, self.b_embeddings

    def __repr__(self):
        return "PairedLayer4Extractor"


class iNaturalistPairedLayer4Extractor(torch.nn.Module):
    """
    Extracting layer-4 embedding for source and target images using iNaturalist ResNet-50 features
    """

    def __init__(self):
        super(iNaturalistPairedLayer4Extractor, self).__init__()

        self.modelA = resnet50_features(inat=True, pretrained=True)
        self.modelA.eval()

        self.modelB = resnet50_features(inat=True, pretrained=True)
        self.modelB.eval()

        self.source_embedding = None
        self.target_embedding = None

    def forward(self, inputs):
        source_image, target_image = inputs
        self.source_embedding = self.modelA(source_image)
        self.target_embedding = self.modelB(target_image)

        return self.source_embedding, self.target_embedding

    def __repr__(self):
        return "iNatPairedLayer4Extractor"
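The reranker's scoring rests on three of these helpers: normalize_array min-max scales a CC similarity map, generate_mask binarizes it at threshold t, and arg_topK picks the strongest surviving cells of the masked 7x7 cosine map. A self-contained sketch with random stand-in maps:

import numpy as np
from Utils import normalize_array, generate_mask, arg_topK

rng = np.random.default_rng(0)
cc_map = rng.random((7, 7))       # stand-in for a CC spatial-similarity map
cosine_map = rng.random((7, 7))   # stand-in for a CHM-guided cosine map

mask = generate_mask(normalize_array(cc_map), t=0.55)  # 1 where normalized value >= t
masked = np.multiply(cosine_map, mask)
print(arg_topK(masked, topK=5))   # indices of the 5 strongest correspondences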
app.py
CHANGED
@@ -10,12 +10,21 @@ from torchvision.datasets import ImageFolder
 
 from SimSearch import FaissCosineNeighbors, SearchableTrainingSet
 from ExtractEmbedding import QueryToEmbedding
+from CHMCorr import chm_classify_and_visualize
+from visualization import plot_from_reranker_output
 
 csv.field_size_limit(sys.maxsize)
 
 concat = lambda x: np.concatenate(x, axis=0)
 
-
+# Embeddings
+gdown.cached_download(
+    url="https://drive.google.com/uc?id=116CiA_cXciGSl72tbAUDoN-f1B9Frp89",
+    path="./embeddings.pkl",
+    quiet=False,
+    md5="002b2a7f5c80d910b9cc740c2265f058",
+)
+
 gdown.download(id="1SDtq6ap7LPPpYfLbAxaMGGmj0EAV_m_e")
 
 # CUB training set
@@ -26,13 +35,21 @@ gdown.cached_download(
     md5="1bd99e73b2fea8e4c2ebcb0e7722f1b1",
 )
 
-# EXTRACT
+# EXTRACT training set
 torchvision.datasets.utils.extract_archive(
     from_path="CUB_train.zip",
-    to_path="
+    to_path="data/",
     remove_finished=False,
 )
 
+# CHM Weights
+gdown.cached_download(
+    url="https://drive.google.com/u/0/uc?id=1zsJRlAsoOn5F0GTCprSFYwDDfV85xDy6&export=download",
+    path="pas_psi.pt",
+    quiet=False,
+    md5="6b7b4d7bad7f89600fac340d6aa7708b",
+)
+
 
 # Caluclate Accuracy
 with open(f"./embeddings.pickle", "rb") as f:
@@ -45,35 +62,53 @@
 searcher.build_index()
 
 # Extract label names
-training_folder = ImageFolder(root="./
+training_folder = ImageFolder(root="./data/train/")
 id_to_bird_name = {
     x[1]: x[0].split("/")[-2].replace(".", " ") for x in training_folder.imgs
 }
 
 
-def search(
-    query_embedding = QueryToEmbedding(
-
+def search(query_image, searcher=searcher):
+    query_embedding = QueryToEmbedding(query_image)
+    scores, indices, labels = searcher.search(query_embedding, k=50)
 
     result_ctr = Counter(labels[0][:20]).most_common(5)
 
     top1_label = result_ctr[0][0]
     top_indices = []
 
-    for a, b in zip(labels[0][:20],
+    for a, b in zip(labels[0][:20], indices[0][:20]):
         if a == top1_label:
             top_indices.append(b)
 
     gallery_images = [training_folder.imgs[int(X)][0] for X in top_indices[:5]]
     predicted_labels = {id_to_bird_name[X[0]]: X[1] / 20.0 for X in result_ctr}
 
-
+    print("gallery_images:", gallery_images)
+
+    # CHM Prediction
+    kNN_results = (top1_label, result_ctr[0][1], gallery_images)
+    support_files = [training_folder.imgs[int(X)][0] for X in indices[0]]
+
+    print(support_files)
+    support_labels = [training_folder.imgs[int(X)][1] for X in indices[0]]
+    print(support_labels)
+
+    support = [support_files, support_labels]
+
+    chm_output = chm_classify_and_visualize(
+        query_image, kNN_results, support, training_folder
+    )
+
+    viz_plot = plot_from_reranker_output(chm_output, draw_arcs=False)
+
+    return predicted_labels, gallery_images, viz_plot
 
 
 demo = gr.Interface(
     search,
-    gr.Image(type="
-    ["label", "gallery"],
+    gr.Image(type="filepath"),
+    ["label", "gallery", "plot"],
     examples=[["./examples/bird.jpg"]],
     description="WIP - kNN on CUB dataset",
     title="Work in Progress - CHM-Corr",
common/evaluation.py
ADDED
@@ -0,0 +1,32 @@
r""" Evaluates CHMNet with PCK """

import torch


class Evaluator:
    r""" Computes evaluation metrics of PCK """
    @classmethod
    def initialize(cls, alpha):
        cls.alpha = torch.tensor(alpha).unsqueeze(1)

    @classmethod
    def evaluate(cls, prd_kps, batch):
        r""" Compute percentage of correct key-points (PCK) with multiple alpha {0.05, 0.1, 0.15} """

        pcks = []
        for idx, (pk, tk) in enumerate(zip(prd_kps, batch['trg_kps'])):
            pckthres = batch['pckthres'][idx]
            npt = batch['n_pts'][idx]
            prd_kps = pk[:, :npt]
            trg_kps = tk[:, :npt]

            l2dist = (prd_kps - trg_kps).pow(2).sum(dim=0).pow(0.5).unsqueeze(0).repeat(len(cls.alpha), 1)
            thres = pckthres.expand_as(l2dist).float() * cls.alpha
            pck = torch.le(l2dist, thres).sum(dim=1) / float(npt)
            if len(pck) == 1: pck = pck[0]
            pcks.append(pck)

        eval_result = {'pck': pcks}

        return eval_result
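For reference, a minimal sketch of driving this evaluator; the batch layout below is inferred from the keys the code reads ('trg_kps', 'pckthres', 'n_pts') and is an assumption, not part of this commit:

    import torch
    from common.evaluation import Evaluator

    Evaluator.initialize([0.05, 0.10, 0.15])      # the three alpha thresholds from the docstring

    # One sample: (2, max_pts) keypoint tensors, 4 of 10 slots valid.
    prd_kps = [torch.rand(2, 10) * 240]
    batch = {
        'trg_kps': [torch.rand(2, 10) * 240],     # ground-truth keypoints
        'pckthres': torch.tensor([[240.0]]),      # per-image PCK threshold
        'n_pts': torch.tensor([4]),               # number of valid keypoints
    }
    print(Evaluator.evaluate(prd_kps, batch))     # {'pck': [tensor with one ratio per alpha]}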
common/logger.py
ADDED
@@ -0,0 +1,117 @@
r""" Logging """

import datetime
import logging
import os

from tensorboardX import SummaryWriter
import torch


class Logger:
    r""" Writes results of training/testing """
    @classmethod
    def initialize(cls, args, training):
        logtime = datetime.datetime.now().__format__('_%m%d_%H%M%S')
        logpath = args.logpath if training else '_TEST_' + args.load.split('/')[-1].split('.')[0] + logtime
        if logpath == '': logpath = logtime

        cls.logpath = os.path.join('logs', logpath + '.log')
        cls.benchmark = args.benchmark
        os.makedirs(cls.logpath)

        logging.basicConfig(filemode='w',
                            filename=os.path.join(cls.logpath, 'log.txt'),
                            level=logging.INFO,
                            format='%(message)s',
                            datefmt='%m-%d %H:%M:%S')

        # Console log config
        console = logging.StreamHandler()
        console.setLevel(logging.INFO)
        formatter = logging.Formatter('%(message)s')
        console.setFormatter(formatter)
        logging.getLogger('').addHandler(console)

        # Tensorboard writer
        cls.tbd_writer = SummaryWriter(os.path.join(cls.logpath, 'tbd/runs'))

        # Log arguments
        if training:
            logging.info(':======== Convolutional Hough Matching Networks =========')
            for arg_key in args.__dict__:
                logging.info('| %20s: %-24s' % (arg_key, str(args.__dict__[arg_key])))
            logging.info(':========================================================\n')

    @classmethod
    def info(cls, msg):
        r""" Writes message to .txt """
        logging.info(msg)

    @classmethod
    def save_model(cls, model, epoch, val_pck):
        torch.save(model.state_dict(), os.path.join(cls.logpath, 'pck_best_model.pt'))
        cls.info('Model saved @%d w/ val. PCK: %5.2f.\n' % (epoch, val_pck))


class AverageMeter:
    r""" Stores loss, evaluation results, selected layers """
    def __init__(self, benchmark):
        r""" Constructor of AverageMeter """
        self.buffer_keys = ['pck']
        self.buffer = {}
        for key in self.buffer_keys:
            self.buffer[key] = []

        self.loss_buffer = []

    def update(self, eval_result, loss=None):
        for key in self.buffer_keys:
            self.buffer[key] += eval_result[key]

        if loss is not None:
            self.loss_buffer.append(loss)

    def write_result(self, split, epoch):
        msg = '\n*** %s ' % split
        msg += '[@Epoch %02d] ' % epoch

        if len(self.loss_buffer) > 0:
            msg += 'Loss: %5.2f ' % (sum(self.loss_buffer) / len(self.loss_buffer))

        for key in self.buffer_keys:
            msg += '%s: %6.2f ' % (key.upper(), sum(self.buffer[key]) / len(self.buffer[key]))
        msg += '***\n'
        Logger.info(msg)

    def write_process(self, batch_idx, datalen, epoch):
        msg = '[Epoch: %02d] ' % epoch
        msg += '[Batch: %04d/%04d] ' % (batch_idx+1, datalen)
        if len(self.loss_buffer) > 0:
            msg += 'Loss: %5.2f ' % self.loss_buffer[-1]
            msg += 'Avg Loss: %5.5f ' % (sum(self.loss_buffer) / len(self.loss_buffer))

        for key in self.buffer_keys:
            msg += 'Avg %s: %5.2f ' % (key.upper(), sum(self.buffer[key]) / len(self.buffer[key]) * 100)
        Logger.info(msg)

    def write_test_process(self, batch_idx, datalen):
        msg = '[Batch: %04d/%04d] ' % (batch_idx+1, datalen)

        for key in self.buffer_keys:
            if key == 'pck':
                pcks = torch.stack(self.buffer[key]).mean(dim=0) * 100
                val = ''
                for p in pcks:
                    val += '%5.2f ' % p.item()
                msg += 'Avg %s: %s ' % (key.upper(), val)
            else:
                msg += 'Avg %s: %5.2f ' % (key.upper(), sum(self.buffer[key]) / len(self.buffer[key]))
        Logger.info(msg)

    def get_test_result(self):
        result = {}
        for key in self.buffer_keys:
            result[key] = torch.stack(self.buffer[key]).mean(dim=0) * 100

        return result
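A rough usage sketch; the argparse fields below mirror exactly what Logger.initialize reads off args and are assumptions beyond that:

    import argparse
    from common.logger import Logger, AverageMeter

    args = argparse.Namespace(logpath='chm_run', benchmark='pfpascal', load='')
    Logger.initialize(args, training=True)    # creates logs/chm_run.log/ with log.txt and a tensorboard dir
    meter = AverageMeter(args.benchmark)
    meter.update({'pck': [0.72]}, loss=0.35)  # buffer one PCK value and one loss
    meter.write_result('Validation', epoch=0)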
examples/Red_Winged_Blackbird_0012_6015.jpg
ADDED
examples/Red_Winged_Blackbird_0025_5342.jpg
ADDED
examples/Yellow_Headed_Blackbird_0020_8549.jpg
ADDED
examples/Yellow_Headed_Blackbird_0026_8545.jpg
ADDED
examples/sample1.jpeg
ADDED
examples/sample2.jpeg
ADDED
model/base/backbone.py
ADDED
@@ -0,0 +1,136 @@
r""" ResNet-101 backbone network """

import torch.utils.model_zoo as model_zoo
import torch.nn as nn
import torch


__all__ = ['Backbone', 'resnet101']


model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    r""" 3x3 convolution with padding (groups=2 keeps the source/target streams separate) """
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, groups=2, bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    r""" 1x1 convolution """
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, groups=2, bias=False)


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Backbone(nn.Module):
    def __init__(self, block, layers, zero_init_residual=False):
        super(Backbone, self).__init__()

        self.inplanes = 128
        # 6 input channels: the source and target RGB images are stacked along the channel dim
        self.conv1 = nn.Conv2d(6, 128, kernel_size=7, stride=2, padding=3, groups=2,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(128)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 128, layers[0])
        self.layer2 = self._make_layer(block, 256, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 512, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 1024, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, 1000)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)


def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = Backbone(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        weights = model_zoo.load_url(model_urls['resnet101'])

        # Duplicate the ImageNet weights along the output channels so each of the
        # two grouped input streams gets its own copy; the fc head is kept as-is.
        for key in weights:
            if key.split('.')[0] == 'fc':
                weights[key] = weights[key].clone()
                continue
            weights[key] = torch.cat([weights[key].clone(), weights[key].clone()], dim=0)

        model.load_state_dict(weights)
    return model
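Note that Backbone defines no forward(); CHMNet (further below) drives its layers by hand. A quick shape check of the two-stream stem (illustrative sketch):

    import torch
    from model.base.backbone import resnet101

    net = resnet101(pretrained=False)
    x = torch.randn(1, 6, 240, 240)          # source + target RGB images stacked channel-wise
    feat = net.maxpool(net.relu(net.bn1(net.conv1(x))))
    for idx in range(1, 4):                  # layer1..layer3, as CHMNet uses them
        feat = getattr(net, 'layer%d' % idx)(feat)
    print(feat.shape)                        # torch.Size([1, 2048, 15, 15]): two 1024-channel streams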
model/base/chm.py
ADDED
@@ -0,0 +1,190 @@
r""" 4D and 6D convolutional Hough matching layers """

from torch.nn.modules.conv import _ConvNd
import torch.nn.functional as F
import torch.nn as nn
import torch

from common.logger import Logger
from . import chm_kernel


def fast4d(corr, kernel, bias=None):
    r""" Optimized implementation of 4D convolution """
    bsz, ch, srch, srcw, trgh, trgw = corr.size()
    out_channels, _, kernel_size, kernel_size, kernel_size, kernel_size = kernel.size()
    psz = kernel_size // 2

    out_corr = torch.zeros((bsz, out_channels, srch, srcw, trgh, trgw))
    corr = corr.transpose(1, 2).contiguous().view(bsz * srch, ch, srcw, trgh, trgw)

    for pidx, k3d in enumerate(kernel.permute(2, 0, 1, 3, 4, 5)):
        inter_corr = F.conv3d(corr, k3d, bias=None, stride=1, padding=psz)
        inter_corr = inter_corr.view(bsz, srch, out_channels, srcw, trgh, trgw).transpose(1, 2).contiguous()

        add_sid = max(psz - pidx, 0)
        add_fid = min(srch, srch + psz - pidx)
        slc_sid = max(pidx - psz, 0)
        slc_fid = min(srch, srch - psz + pidx)

        out_corr[:, :, add_sid:add_fid, :, :, :] += inter_corr[:, :, slc_sid:slc_fid, :, :, :]

    if bias is not None:
        out_corr += bias.view(1, out_channels, 1, 1, 1, 1)

    return out_corr


def fast6d(corr, kernel, bias, diagonal_idx):
    r""" Optimized implementation of 6D convolutional Hough matching
    NOTE: this function only supports kernel size of (3, 3, 5, 5, 5, 5).
    """
    bsz, _, s6d, s6d, s4d, s4d, s4d, s4d = corr.size()
    _, _, ks6d, ks6d, ks4d, ks4d, ks4d, ks4d = kernel.size()
    corr = corr.permute(0, 2, 3, 1, 4, 5, 6, 7).contiguous().view(-1, 1, s4d, s4d, s4d, s4d)
    kernel = kernel.view(-1, ks6d ** 2, ks4d, ks4d, ks4d, ks4d).transpose(0, 1)
    corr = fast4d(corr, kernel).view(bsz, s6d * s6d, ks6d * ks6d, s4d, s4d, s4d, s4d)
    corr = corr.view(bsz, s6d, s6d, ks6d, ks6d, s4d, s4d, s4d, s4d).transpose(2, 3).\
        contiguous().view(-1, s6d * ks6d, s4d, s4d, s4d, s4d)

    ndiag = s6d + (ks6d // 2) * 2
    first_sum = []
    for didx in diagonal_idx:
        first_sum.append(corr[:, didx, :, :, :, :].sum(dim=1))
    first_sum = torch.stack(first_sum).transpose(0, 1).view(bsz, s6d * ks6d, ndiag, s4d, s4d, s4d, s4d)

    corr = []
    for didx in diagonal_idx:
        corr.append(first_sum[:, didx, :, :, :, :, :].sum(dim=1))
    sidx = ks6d // 2
    eidx = ndiag - sidx
    corr = torch.stack(corr).transpose(0, 1)[:, sidx:eidx, sidx:eidx, :, :, :, :].unsqueeze(1).contiguous()
    corr += bias.view(1, -1, 1, 1, 1, 1, 1, 1)

    reverse_idx = torch.linspace(s6d * s6d - 1, 0, s6d * s6d).long()
    corr = corr.view(bsz, 1, s6d * s6d, s4d, s4d, s4d, s4d)[:, :, reverse_idx, :, :, :, :].\
        view(bsz, 1, s6d, s6d, s4d, s4d, s4d, s4d)
    return corr


def init_param_idx4d(param_dict):
    param_idx = []
    for key in param_dict:
        curr_offset = int(key.split('_')[-1])
        param_idx.append(torch.tensor(param_dict[key]))
    return param_idx


class CHM4d(_ConvNd):
    r""" 4D convolutional Hough matching layer
    NOTE: this function only supports in_channels=1 and out_channels=1.
    """
    def __init__(self, in_channels, out_channels, ksz4d, ktype, bias=True):
        super(CHM4d, self).__init__(in_channels, out_channels, (ksz4d,) * 4,
                                    (1,) * 4, (0,) * 4, (1,) * 4, False, (0,) * 4,
                                    1, bias, padding_mode='zeros')

        # Zero kernel initialization
        self.zero_kernel4d = torch.zeros((in_channels, out_channels, ksz4d, ksz4d, ksz4d, ksz4d))
        self.nkernels = in_channels * out_channels

        # Initialize kernel indices
        param_dict4d = chm_kernel.KernelGenerator(ksz4d, ktype).generate()
        param_shared = param_dict4d is not None

        if param_shared:
            # Initialize the shared parameters (multiplied by the number of times being shared)
            self.param_idx = init_param_idx4d(param_dict4d)
            weights = torch.abs(torch.randn(len(self.param_idx) * self.nkernels)) * 1e-3
            for weight, param_idx in zip(weights.sort()[0], self.param_idx):
                weight *= len(param_idx)
            self.weight = nn.Parameter(weights)
        else:  # full kernel initialization
            self.param_idx = None
            self.weight = nn.Parameter(torch.abs(self.weight))
        if bias: self.bias = nn.Parameter(torch.tensor(0.0))
        Logger.info('(%s) # params in CHM 4D: %d' % (ktype, len(self.weight.view(-1))))

    def forward(self, x):
        kernel = self.init_kernel()
        x = fast4d(x, kernel, self.bias)
        return x

    def init_kernel(self):
        # Initialize CHM kernel (divided by the number of times being shared)
        ksz = self.kernel_size[-1]
        if self.param_idx is None:
            kernel = self.weight
        else:
            kernel = torch.zeros_like(self.zero_kernel4d)
            for idx, pdx in enumerate(self.param_idx):
                kernel = kernel.view(-1, ksz, ksz, ksz, ksz)
                for jdx, kernel_single in enumerate(kernel):
                    weight = self.weight[idx + jdx * len(self.param_idx)].repeat(len(pdx)) / len(pdx)
                    kernel_single.view(-1)[pdx] += weight
            kernel = kernel.view(self.in_channels, self.out_channels, ksz, ksz, ksz, ksz)
        return kernel


class CHM6d(_ConvNd):
    r""" 6D convolutional Hough matching layer with kernel (3, 3, 5, 5, 5, 5)
    NOTE: this function only supports in_channels=1 and out_channels=1.
    """
    def __init__(self, in_channels, out_channels, ksz6d, ksz4d, ktype):
        kernel_size = (ksz6d, ksz6d, ksz4d, ksz4d, ksz4d, ksz4d)
        super(CHM6d, self).__init__(in_channels, out_channels, kernel_size, (1,) * 6,
                                    (0,) * 6, (1,) * 6, False, (0,) * 6,
                                    1, bias=True, padding_mode='zeros')

        # Zero kernel initialization
        self.zero_kernel4d = torch.zeros((ksz4d, ksz4d, ksz4d, ksz4d))
        self.zero_kernel6d = torch.zeros((ksz6d, ksz6d, ksz4d, ksz4d, ksz4d, ksz4d))
        self.nkernels = in_channels * out_channels

        # Initialize kernel indices
        # Indices in scale-space where 4D convolutions are performed (3 by 3 scale-space)
        self.diagonal_idx = [torch.tensor(x) for x in [[6], [3, 7], [0, 4, 8], [1, 5], [2]]]
        param_dict4d = chm_kernel.KernelGenerator(ksz4d, ktype).generate()
        param_shared = param_dict4d is not None

        if param_shared:  # psi & iso kernel initialization
            if ktype == 'psi':
                self.param_dict6d = [[4], [0, 8], [2, 6], [1, 3, 5, 7]]
            elif ktype == 'iso':
                self.param_dict6d = [[0, 4, 8], [2, 6], [1, 3, 5, 7]]
            self.param_dict6d = [torch.tensor(i) for i in self.param_dict6d]

            # Initialize the shared parameters (multiplied by the number of times being shared)
            self.param_idx = init_param_idx4d(param_dict4d)
            self.param = []
            for param_dict6d in self.param_dict6d:
                weights = torch.abs(torch.randn(len(self.param_idx))) * 1e-3
                for weight, param_idx in zip(weights, self.param_idx):
                    weight *= (len(param_idx) * len(param_dict6d))
                self.param.append(nn.Parameter(weights))
            self.param = nn.ParameterList(self.param)
        else:  # full kernel initialization
            self.param_idx = None
            self.param = nn.Parameter(torch.abs(self.weight) * 1e-3)
        Logger.info('(%s) # params in CHM 6D: %d' % (ktype, sum([len(x.view(-1)) for x in self.param])))
        self.weight = None

    def forward(self, corr):
        kernel = self.init_kernel()
        corr = fast6d(corr, kernel, self.bias, self.diagonal_idx)
        return corr

    def init_kernel(self):
        # Initialize CHM kernel (divided by the number of times being shared)
        if self.param_idx is None:
            return self.param

        kernel6d = torch.zeros_like(self.zero_kernel6d)
        for idx, (param, param_dict6d) in enumerate(zip(self.param, self.param_dict6d)):
            ksz4d = self.kernel_size[-1]
            kernel4d = torch.zeros_like(self.zero_kernel4d)
            for jdx, pdx in enumerate(self.param_idx):
                kernel4d.view(-1)[pdx] += ((param[jdx] / len(pdx)) / len(param_dict6d))
            kernel6d.view(-1, ksz4d, ksz4d, ksz4d, ksz4d)[param_dict6d] += kernel4d.view(ksz4d, ksz4d, ksz4d, ksz4d)
        kernel6d = kernel6d.unsqueeze(0).unsqueeze(0)

        return kernel6d
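To see fast4d on a toy correlation tensor (standalone sketch; shapes follow the signature above):

    import torch
    from model.base.chm import fast4d

    corr = torch.rand(1, 1, 6, 6, 6, 6)      # (bsz, ch, src_h, src_w, trg_h, trg_w)
    kernel = torch.rand(1, 1, 3, 3, 3, 3)    # (out_ch, in_ch, k, k, k, k)
    print(fast4d(corr, kernel).shape)        # torch.Size([1, 1, 6, 6, 6, 6])

The trick is visible in the loop: the 4D convolution is decomposed into a bank of 3D convolutions over (src_w, trg_h, trg_w), shifted along src_h and summed.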
model/base/chm_kernel.py
ADDED
@@ -0,0 +1,66 @@
r""" CHM 4D kernel (psi, iso, and full) generator """

import torch

from .geometry import Geometry


class KernelGenerator:
    def __init__(self, ksz, ktype):
        self.ksz = ksz
        self.idx4d = Geometry.init_idx4d(ksz)
        self.kernel = torch.zeros((ksz, ksz, ksz, ksz))
        self.center = (ksz // 2, ksz // 2)
        self.ktype = ktype

    def quadrant(self, crd):
        if crd[0] < self.center[0]:
            horz_quad = -1
        elif crd[0] > self.center[0]:
            horz_quad = 1
        else:
            horz_quad = 0

        if crd[1] < self.center[1]:
            vert_quad = -1
        elif crd[1] > self.center[1]:
            vert_quad = 1
        else:
            vert_quad = 0

        return horz_quad, vert_quad

    def generate(self):
        return None if self.ktype == 'full' else self.generate_chm_kernel()

    def generate_chm_kernel(self):
        param_dict = {}
        for idx in self.idx4d:
            src_i, src_j, trg_i, trg_j = idx
            d_tail = Geometry.get_distance((src_i, src_j), self.center)
            d_head = Geometry.get_distance((trg_i, trg_j), self.center)
            d_off = Geometry.get_distance((src_i, src_j), (trg_i, trg_j))
            horz_quad, vert_quad = self.quadrant((src_j, src_i))

            src_crd = (src_i, src_j)
            trg_crd = (trg_i, trg_j)

            key = self.build_key(horz_quad, vert_quad, d_head, d_tail, src_crd, trg_crd, d_off)
            coord1d = Geometry.get_coord1d((src_i, src_j, trg_i, trg_j), self.ksz)

            if param_dict.get(key) is None: param_dict[key] = []
            param_dict[key].append(coord1d)

        return param_dict

    def build_key(self, horz_quad, vert_quad, d_head, d_tail, src_crd, trg_crd, d_off):

        if self.ktype == 'iso':
            return '%d' % d_off
        elif self.ktype == 'psi':
            d_max = max(d_head, d_tail)
            d_min = min(d_head, d_tail)
            return '%d_%d_%d' % (d_max, d_min, d_off)
        else:
            raise Exception('not implemented.')
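The degree of weight sharing can be inspected directly (sketch):

    from model.base.chm_kernel import KernelGenerator

    param_dict = KernelGenerator(ksz=3, ktype='psi').generate()
    print(len(param_dict))            # number of distinct shared weights in a 3x3x3x3 'psi' kernel
    key = next(iter(param_dict))
    print(key, len(param_dict[key]))  # a 'dmax_dmin_doff' key and how many 4D positions share it

With ktype='full', generate() returns None and the CHM layers above fall back to an unshared kernel.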
model/base/correlation.py
ADDED
@@ -0,0 +1,68 @@
r""" Provides functions that create/manipulate correlation matrices """

import math

from torch.nn.functional import interpolate as resize
import torch

from .geometry import Geometry


class Correlation:

    @classmethod
    def mutual_nn_filter(cls, correlation_matrix, eps=1e-30):
        r""" Mutual nearest neighbor filtering (Rocco et al., NeurIPS'18) """
        corr_src_max = torch.max(correlation_matrix, dim=2, keepdim=True)[0]
        corr_trg_max = torch.max(correlation_matrix, dim=1, keepdim=True)[0]
        corr_src_max[corr_src_max == 0] += eps
        corr_trg_max[corr_trg_max == 0] += eps

        corr_src = correlation_matrix / corr_src_max
        corr_trg = correlation_matrix / corr_trg_max

        return correlation_matrix * (corr_src * corr_trg)

    @classmethod
    def build_correlation6d(self, src_feat, trg_feat, scales, conv2ds):
        r""" Build 6-dimensional correlation tensor """

        bsz, _, side, side = src_feat.size()

        # Construct feature pairs with multiple scales
        _src_feats = []
        _trg_feats = []
        for scale, conv in zip(scales, conv2ds):
            s = (round(side * math.sqrt(scale)),) * 2
            _src_feat = conv(resize(src_feat, s, mode='bilinear', align_corners=True))
            _trg_feat = conv(resize(trg_feat, s, mode='bilinear', align_corners=True))
            _src_feats.append(_src_feat)
            _trg_feats.append(_trg_feat)

        # Build multiple 4-dimensional correlation tensors
        corr6d = []
        for src_feat in _src_feats:
            ch = src_feat.size(1)

            src_side = src_feat.size(-1)
            src_feat = src_feat.view(bsz, ch, -1).transpose(1, 2)
            src_norm = src_feat.norm(p=2, dim=2, keepdim=True)

            for trg_feat in _trg_feats:
                trg_side = trg_feat.size(-1)
                trg_feat = trg_feat.view(bsz, ch, -1)
                trg_norm = trg_feat.norm(p=2, dim=1, keepdim=True)

                correlation = torch.bmm(src_feat, trg_feat) / torch.bmm(src_norm, trg_norm)
                correlation = correlation.view(bsz, src_side, src_side, trg_side, trg_side).contiguous()
                corr6d.append(correlation)

        # Resize the spatial sizes of the 4D tensors to the same size
        for idx, correlation in enumerate(corr6d):
            corr6d[idx] = Geometry.interpolate4d(correlation, [side, side])

        # Build 6-dimensional correlation tensor
        corr6d = torch.stack(corr6d).view(len(scales), len(scales),
                                          bsz, side, side, side, side).permute(2, 0, 1, 3, 4, 5, 6)
        return corr6d.clamp(min=0)
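mutual_nn_filter downweights matches that are not near-best in both directions; a tiny sanity check (standalone sketch):

    import torch
    from model.base.correlation import Correlation

    corr = torch.tensor([[[1.0, 0.2],
                          [0.3, 0.9]]])   # (bsz, src_positions, trg_positions)
    print(Correlation.mutual_nn_filter(corr))
    # Diagonal entries (mutual nearest neighbors) keep their full score;
    # off-diagonal scores are strongly suppressed.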
model/base/geometry.py
ADDED
@@ -0,0 +1,133 @@
r""" Provides functions that manipulate boxes and points """

import math

import torch.nn.functional as F
import torch


class Geometry(object):

    @classmethod
    def initialize(cls, img_size):
        cls.img_size = img_size

        cls.spatial_side = int(img_size / 8)
        norm_grid1d = torch.linspace(-1, 1, cls.spatial_side)

        cls.norm_grid_x = norm_grid1d.view(1, -1).repeat(cls.spatial_side, 1).view(1, 1, -1)
        cls.norm_grid_y = norm_grid1d.view(-1, 1).repeat(1, cls.spatial_side).view(1, 1, -1)
        cls.grid = torch.stack(list(reversed(torch.meshgrid(norm_grid1d, norm_grid1d)))).permute(1, 2, 0)

        cls.feat_idx = torch.arange(0, cls.spatial_side).float()

    @classmethod
    def normalize_kps(cls, kps):
        kps = kps.clone().detach()
        kps[kps != -2] -= (cls.img_size // 2)
        kps[kps != -2] /= (cls.img_size // 2)
        return kps

    @classmethod
    def unnormalize_kps(cls, kps):
        kps = kps.clone().detach()
        kps[kps != -2] *= (cls.img_size // 2)
        kps[kps != -2] += (cls.img_size // 2)
        return kps

    @classmethod
    def attentive_indexing(cls, kps, thres=0.1):
        r""" kps: normalized keypoints x, y (N, 2)
        returns attentive index map (N, spatial_side, spatial_side)
        """
        nkps = kps.size(0)
        kps = kps.view(nkps, 1, 1, 2)

        eps = 1e-5
        attmap = (cls.grid.unsqueeze(0).repeat(nkps, 1, 1, 1) - kps).pow(2).sum(dim=3)
        attmap = (attmap + eps).pow(0.5)
        attmap = (thres - attmap).clamp(min=0).view(nkps, -1)
        attmap = attmap / attmap.sum(dim=1, keepdim=True)
        attmap = attmap.view(nkps, cls.spatial_side, cls.spatial_side)

        return attmap

    @classmethod
    def apply_gaussian_kernel(cls, corr, sigma=17):
        bsz, side, side = corr.size()

        center = corr.max(dim=2)[1]
        center_y = center // cls.spatial_side
        center_x = center % cls.spatial_side

        y = cls.feat_idx.view(1, 1, cls.spatial_side).repeat(bsz, center_y.size(1), 1) - center_y.unsqueeze(2)
        x = cls.feat_idx.view(1, 1, cls.spatial_side).repeat(bsz, center_x.size(1), 1) - center_x.unsqueeze(2)

        y = y.unsqueeze(3).repeat(1, 1, 1, cls.spatial_side)
        x = x.unsqueeze(2).repeat(1, 1, cls.spatial_side, 1)

        gauss_kernel = torch.exp(-(x.pow(2) + y.pow(2)) / (2 * sigma ** 2))
        filtered_corr = gauss_kernel * corr.view(bsz, -1, cls.spatial_side, cls.spatial_side)
        filtered_corr = filtered_corr.view(bsz, side, side)

        return filtered_corr

    @classmethod
    def transfer_kps(cls, confidence_ts, src_kps, n_pts, normalized):
        r""" Transfer keypoints by weighted average """

        if not normalized:
            src_kps = Geometry.normalize_kps(src_kps)
        confidence_ts = cls.apply_gaussian_kernel(confidence_ts)

        pdf = F.softmax(confidence_ts, dim=2)
        prd_x = (pdf * cls.norm_grid_x).sum(dim=2)
        prd_y = (pdf * cls.norm_grid_y).sum(dim=2)

        prd_kps = []
        for idx, (x, y, src_kp, np) in enumerate(zip(prd_x, prd_y, src_kps, n_pts)):
            max_pts = src_kp.size()[1]
            prd_xy = torch.stack([x, y]).t()

            src_kp = src_kp[:, :np].t()
            attmap = cls.attentive_indexing(src_kp).view(np, -1)
            prd_kp = (prd_xy.unsqueeze(0) * attmap.unsqueeze(-1)).sum(dim=1).t()
            pads = (torch.zeros((2, max_pts - np)) - 2)
            prd_kp = torch.cat([prd_kp, pads], dim=1)
            prd_kps.append(prd_kp)

        return torch.stack(prd_kps)

    @staticmethod
    def get_coord1d(coord4d, ksz):
        i, j, k, l = coord4d
        coord1d = i * (ksz ** 3) + j * (ksz ** 2) + k * (ksz) + l
        return coord1d

    @staticmethod
    def get_distance(coord1, coord2):
        delta_y = int(math.pow(coord1[0] - coord2[0], 2))
        delta_x = int(math.pow(coord1[1] - coord2[1], 2))
        dist = delta_y + delta_x
        return dist

    @staticmethod
    def interpolate4d(tensor4d, size):
        bsz, h1, w1, h2, w2 = tensor4d.size()
        tensor4d = tensor4d.view(bsz, h1, w1, -1).permute(0, 3, 1, 2)
        tensor4d = F.interpolate(tensor4d, size, mode='bilinear', align_corners=True)
        tensor4d = tensor4d.view(bsz, h2, w2, -1).permute(0, 3, 1, 2)
        tensor4d = F.interpolate(tensor4d, size, mode='bilinear', align_corners=True)
        tensor4d = tensor4d.view(bsz, size[0], size[0], size[0], size[0])

        return tensor4d

    @staticmethod
    def init_idx4d(ksz):
        i0 = torch.arange(0, ksz).repeat(ksz ** 3)
        i1 = torch.arange(0, ksz).unsqueeze(1).repeat(1, ksz).view(-1).repeat(ksz ** 2)
        i2 = torch.arange(0, ksz).unsqueeze(1).repeat(1, ksz ** 2).view(-1).repeat(ksz)
        i3 = torch.arange(0, ksz).unsqueeze(1).repeat(1, ksz ** 3).view(-1)
        idx4d = torch.stack([i3, i2, i1, i0]).t().numpy()

        return idx4d
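Geometry keeps class-level state that must be set before any keypoint transfer; a round-trip sketch (240 matches the display size used elsewhere in this commit):

    import torch
    from model.base.geometry import Geometry

    Geometry.initialize(img_size=240)       # spatial_side = 240 // 8 = 30
    kps = torch.tensor([[0., 120., 239.],
                        [0., 120., 239.]])  # (2, n) pixel coordinates; -2 marks padding
    norm = Geometry.normalize_kps(kps)      # mapped into roughly [-1, 1]
    back = Geometry.unnormalize_kps(norm)
    print(norm[:, 1], back[:, 1])           # tensor([0., 0.]) tensor([120., 120.])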
model/chmlearner.py
ADDED
@@ -0,0 +1,52 @@
r""" Convolutional Hough matching layers """

import torch.nn as nn
import torch

from .base.correlation import Correlation
from .base.geometry import Geometry
from .base.chm import CHM4d, CHM6d


class CHMLearner(nn.Module):

    def __init__(self, ktype, feat_dim):
        super(CHMLearner, self).__init__()

        # Scale-wise feature transformation
        self.scales = [0.5, 1, 2]
        self.conv2ds = nn.ModuleList([nn.Conv2d(feat_dim, feat_dim // 4, kernel_size=3, padding=1, bias=False) for _ in self.scales])

        # CHM layers
        ksz_translation = 5
        ksz_scale = 3
        self.chm6d = CHM6d(1, 1, ksz_scale, ksz_translation, ktype)
        self.chm4d = CHM4d(1, 1, ksz_translation, ktype, bias=True)

        # Activations
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()
        self.softplus = nn.Softplus()

    def forward(self, src_feat, trg_feat):

        corr = Correlation.build_correlation6d(src_feat, trg_feat, self.scales, self.conv2ds).unsqueeze(1)
        bsz, ch, s, s, h, w, h, w = corr.size()

        # CHM layer (6D)
        corr = self.chm6d(corr)
        corr = self.sigmoid(corr)

        # Scale-space maxpool
        corr = corr.view(bsz, -1, h, w, h, w).max(dim=1)[0]
        corr = Geometry.interpolate4d(corr, [h * 2, w * 2]).unsqueeze(1)

        # CHM layer (4D)
        corr = self.chm4d(corr).squeeze(1)

        # To ensure non-negative vote scores & soft cyclic constraints
        corr = self.softplus(corr)
        corr = Correlation.mutual_nn_filter(corr.view(bsz, corr.size(-1) ** 2, corr.size(-1) ** 2).contiguous())

        return corr
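End to end, the learner turns a pair of (bsz, 1024, h, w) feature maps into a (bsz, (2h*2w), (2h*2w)) correlation matrix: the 6D CHM votes over scale and translation, a max-pool collapses the scale axis, the 4D grid is upsampled 2x, and the 4D CHM plus mutual-NN filtering produce the final scores. A shape check (sketch; random tensors stand in for backbone features):

    import torch
    from model.chmlearner import CHMLearner

    learner = CHMLearner(ktype='psi', feat_dim=1024)
    src = torch.randn(1, 1024, 15, 15)
    trg = torch.randn(1, 1024, 15, 15)
    with torch.no_grad():
        corr = learner(src, trg)
    print(corr.shape)              # torch.Size([1, 900, 900]); 900 = (2*15)**2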
model/chmnet.py
ADDED
@@ -0,0 +1,42 @@
r""" Convolutional Hough Matching Networks """

import torch.nn as nn
import torch

from . import chmlearner as chmlearner
from .base import backbone


class CHMNet(nn.Module):
    def __init__(self, ktype):
        super(CHMNet, self).__init__()

        self.backbone = backbone.resnet101(pretrained=True)
        self.learner = chmlearner.CHMLearner(ktype, feat_dim=1024)

    def forward(self, src_img, trg_img):
        src_feat, trg_feat = self.extract_features(src_img, trg_img)
        correlation = self.learner(src_feat, trg_feat)
        return correlation

    def extract_features(self, src_img, trg_img):
        # Source and target images share one two-stream backbone: stack along channels
        feat = self.backbone.conv1.forward(torch.cat([src_img, trg_img], dim=1))
        feat = self.backbone.bn1.forward(feat)
        feat = self.backbone.relu.forward(feat)
        feat = self.backbone.maxpool.forward(feat)

        for idx in range(1, 5):
            feat = self.backbone.__getattr__('layer%d' % idx)(feat)

            if idx == 3:
                # Split the grouped features back into source / target halves after layer3
                src_feat = feat.narrow(1, 0, feat.size(1) // 2).clone()
                trg_feat = feat.narrow(1, feat.size(1) // 2, feat.size(1) // 2).clone()
                return src_feat, trg_feat

    def training_objective(cls, prd_kps, trg_kps, npts):
        l2dist = (prd_kps - trg_kps).pow(2).sum(dim=1)
        loss = []
        for dist, npt in zip(l2dist, npts):
            loss.append(dist[:npt].mean())
        return torch.stack(loss).mean()
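Putting the pieces together (sketch; pas_psi.pt is the checkpoint downloaded in app.py above, and whether its keys load strictly here is an assumption):

    import torch
    from model import chmnet

    model = chmnet.CHMNet(ktype='psi')
    model.load_state_dict(torch.load('pas_psi.pt', map_location='cpu'))
    model.eval()

    src = torch.randn(1, 3, 240, 240)  # query image tensor
    trg = torch.randn(1, 3, 240, 240)  # one support image tensor
    with torch.no_grad():
        corr = model(src, trg)         # dense source-to-target correlation, e.g. (1, 900, 900)
    print(corr.shape)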
visualization.py
ADDED
@@ -0,0 +1,274 @@
import pickle
from collections import Counter
from itertools import product

import matplotlib
import matplotlib.patches as patches
import numpy as np
import torchvision.transforms as transforms
from matplotlib import gridspec
from matplotlib import pyplot as plt
from matplotlib.patches import ConnectionPatch, ConnectionStyle
from PIL import Image

connectionstyle = ConnectionStyle("Arc3, rad=0.2")

display_transform = transforms.Compose(
    [transforms.Resize(240), transforms.CenterCrop((240, 240))]
)
display_transform_knn = transforms.Compose(
    [transforms.Resize(256), transforms.CenterCrop((224, 224))]
)


def keep_top_k(input_array, K=5):
    """
    return a binary mask marking the top K values of a numpy array
    """
    top_5 = np.sort(input_array.reshape(-1))[::-1][K - 1]
    masked = np.zeros_like(input_array)
    masked[input_array >= top_5] = 1
    return masked


def arg_topK(inputarray, topK=5):
    """
    returns indices of the top K (largest) elements
    """
    return np.argsort(inputarray.T.reshape(-1))[::-1][:topK]


# FOR MULTI
def plot_from_reranker_output(reranker_output, draw_box=True, draw_arcs=True):
    """
    visualize CHM results from a reranker output dict
    """

    ### SET COLORS
    cmap = matplotlib.cm.get_cmap("gist_rainbow")
    rgba = cmap(0.5)
    colors = []
    for k in range(5):
        colors.append(cmap(k / 5.0))

    ### SET POINTS
    A = np.linspace(1 + 17, 240 - 17 - 1, 7)
    point_list = list(product(A, A))

    nrow = 4
    ncol = 7

    fig = plt.figure(figsize=(32, 18))
    gs = gridspec.GridSpec(
        nrow,
        ncol,
        width_ratios=[1, 0.2, 1, 1, 1, 1, 1],
        height_ratios=[1, 1, 1, 1],
        wspace=0.1,
        hspace=0.1,
        top=0.9,
        bottom=0.05,
        left=0.17,
        right=0.845,
    )
    axes = [[None for n in range(ncol - 1)] for x in range(nrow)]

    for i in range(4):
        axes[i] = []
        for j in range(7):
            if j != 1:
                if (i, j) in [(2, 0), (3, 0)]:
                    # rows 2 and 3 show nothing in column 0; reuse the last axis as a placeholder
                    axes[i].append(new_ax)
                else:
                    new_ax = plt.subplot(gs[i, j])
                    new_ax.set_xticklabels([])
                    new_ax.set_xticks([])
                    new_ax.set_yticklabels([])
                    new_ax.set_yticks([])
                    new_ax.axis("off")
                    axes[i].append(new_ax)

    ##################### DRAW EVERYTHING
    axes[0][0].imshow(
        display_transform(Image.open(reranker_output["q"]).convert("RGB"))
    )
    axes[0][0].set_title(
        f'Query - K={reranker_output["K"]}, N={reranker_output["N"]}', fontsize=21
    )

    axes[1][0].imshow(
        display_transform(Image.open(reranker_output["q"]).convert("RGB"))
    )
    axes[1][0].set_title(f'Query - K={reranker_output["K"]}', fontsize=21)

    # axes[2][0].imshow(display_transform(Image.open(reranker_output['q'])))

    # CHM Top5
    for i in range(min(5, reranker_output["chm-prediction-confidence"])):
        axes[0][1 + i].imshow(
            display_transform(
                Image.open(reranker_output["chm-nearest-neighbors"][i]).convert("RGB")
            )
        )
        axes[0][1 + i].set_title(f"CHM - Top - {i+1}", fontsize=21)

    if reranker_output["chm-prediction-confidence"] < 5:
        for i in range(reranker_output["chm-prediction-confidence"], 5):
            axes[0][1 + i].imshow(Image.new(mode="RGB", size=(224, 224), color="white"))
            axes[0][1 + i].set_title(f"", fontsize=21)

    # KNN top5
    for i in range(min(5, reranker_output["knn-prediction-confidence"])):
        axes[1][1 + i].imshow(
            display_transform_knn(
                Image.open(reranker_output["knn-nearest-neighbors"][i]).convert("RGB")
            )
        )
        axes[1][1 + i].set_title(f"kNN - Top - {i+1}", fontsize=21)

    if reranker_output["knn-prediction-confidence"] < 5:
        for i in range(reranker_output["knn-prediction-confidence"], 5):
            axes[1][1 + i].imshow(Image.new(mode="RGB", size=(240, 240), color="white"))
            axes[1][1 + i].set_title(f"", fontsize=21)

    for i in range(min(5, reranker_output["chm-prediction-confidence"])):
        axes[2][i + 1].imshow(
            display_transform(Image.open(reranker_output["q"]).convert("RGB"))
        )

    # Lower rows: CHM Top5
    for i in range(min(5, reranker_output["chm-prediction-confidence"])):
        axes[3][1 + i].imshow(
            display_transform(
                Image.open(reranker_output["chm-nearest-neighbors"][i]).convert("RGB")
            )
        )

    if reranker_output["chm-prediction-confidence"] < 5:
        for i in range(reranker_output["chm-prediction-confidence"], 5):
            axes[2][i + 1].imshow(Image.new(mode="RGB", size=(240, 240), color="white"))
            axes[3][1 + i].imshow(Image.new(mode="RGB", size=(240, 240), color="white"))

    nzm = reranker_output["non_zero_mask"]
    # Go through the top-5 nearest images

    # #################################################################################
    if draw_box:
        # SQUARES
        for NC in range(min(5, reranker_output["chm-prediction-confidence"])):
            # ON SOURCE
            valid_patches_source = arg_topK(
                reranker_output["masked_cos_values"][NC], topK=nzm
            )

            # ON QUERY
            target_masked_patches = arg_topK(
                reranker_output["masked_cos_values"][NC], topK=nzm
            )
            valid_patches_target = [
                reranker_output["correspondance_map"][NC][x]
                for x in target_masked_patches
            ]
            valid_patches_target = [(x[0] * 7) + x[1] for x in valid_patches_target]

            patch_colors = [c for c in colors]
            overlaps = [
                item
                for item, count in Counter(valid_patches_target).items()
                if count > 1
            ]

            for O in overlaps:
                indices = [i for i, val in enumerate(valid_patches_target) if val == O]
                for ii in indices[1:]:
                    patch_colors[ii] = patch_colors[indices[0]]

            for i in valid_patches_source:
                Psource = point_list[i]
                rect = patches.Rectangle(
                    (Psource[0] - 16, Psource[1] - 16),
                    32,
                    32,
                    linewidth=2,
                    edgecolor=patch_colors[valid_patches_source.tolist().index(i)],
                    facecolor="none",
                    alpha=1,
                )
                axes[2][1 + NC].add_patch(rect)

            for i in valid_patches_target:
                Psource = point_list[i]
                rect = patches.Rectangle(
                    (Psource[0] - 16, Psource[1] - 16),
                    32,
                    32,
                    linewidth=2,
                    edgecolor=patch_colors[valid_patches_target.index(i)],
                    facecolor="none",
                    alpha=1,
                )
                axes[3][1 + NC].add_patch(rect)

    #################################################################################
    # Show correspondence lines and points
    if draw_arcs:
        for CK in range(min(5, reranker_output["chm-prediction-confidence"])):
            target_keypoints = []
            topk_index = arg_topK(reranker_output["masked_cos_values"][CK], topK=nzm)
            for i in range(nzm):  # Number of Connections
                con = ConnectionPatch(
                    xyA=(
                        reranker_output["src-keypoints"][CK][i, 0],
                        reranker_output["src-keypoints"][CK][i, 1],
                    ),
                    xyB=(
                        reranker_output["tgt-keypoints"][CK][i, 0],
                        reranker_output["tgt-keypoints"][CK][i, 1],
                    ),
                    coordsA="data",
                    coordsB="data",
                    axesA=axes[2][1 + CK],
                    axesB=axes[3][1 + CK],
                    color=colors[i],
                    connectionstyle=connectionstyle,
                    shrinkA=1.0,
                    shrinkB=1.0,
                    linewidth=1,
                )

                axes[3][1 + CK].add_artist(con)

            # Scatter Plot
            axes[2][1 + CK].scatter(
                reranker_output["src-keypoints"][CK][:, 0],
                reranker_output["src-keypoints"][CK][:, 1],
                c=colors[:nzm],
                s=10,
            )
            axes[3][1 + CK].scatter(
                reranker_output["tgt-keypoints"][CK][:, 0],
                reranker_output["tgt-keypoints"][CK][:, 1],
                c=colors[:nzm],
                s=10,
            )

    fig.text(
        0.5,
        0.95,
        f"CHM: {reranker_output['chm-prediction']}",
        ha="center",
        va="bottom",
        color="black",
        fontsize=22,
    )
    fig.text(
        0.8,
        0.95,
        f"KNN: {reranker_output['knn-prediction']}",
        ha="right",
        va="bottom",
        color="black",
        fontsize=22,
    )

    return fig
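A typical call, matching how app.py consumes this module (sketch; chm_output is the dict produced by chm_classify_and_visualize):

    from visualization import plot_from_reranker_output

    fig = plot_from_reranker_output(chm_output, draw_box=True, draw_arcs=False)
    fig.savefig('chm_corr_explanation.png', bbox_inches='tight')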