Spaces:
Running
Running
Upload models/model.py
Browse files- models/model.py +196 -0
models/model.py
ADDED
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from models.network import HourGlass2, SpixelNet, ColorProbNet
|
5 |
+
from models.transformer2d import EncoderLayer, DecoderLayer, TransformerEncoder, TransformerDecoder
|
6 |
+
from models.position_encoding import build_position_encoding
|
7 |
+
from models import basic, clusterkit, anchor_gen
|
8 |
+
from collections import OrderedDict
|
9 |
+
from utils import util, cielab
|
10 |
+
|
11 |
+
|
12 |
+
class SpixelSeg(nn.Module):
    """Thin ``nn.Module`` wrapper around ``SpixelNet``.

    The wrapped network is stored as ``self.net`` and ``forward`` delegates
    to it without any pre- or post-processing.
    """

    def __init__(self, inChannel=1, outChannel=9, batchNorm=True):
        """Build the wrapped ``SpixelNet``.

        Args:
            inChannel: forwarded to ``SpixelNet`` as ``inChannel``.
            outChannel: forwarded to ``SpixelNet`` as ``outChannel``.
            batchNorm: forwarded to ``SpixelNet`` as ``batchNorm``.
        """
        super().__init__()
        self.net = SpixelNet(inChannel=inChannel, outChannel=outChannel, batchNorm=batchNorm)

    def get_trainable_params(self, lr=1.0):
        """Return one optimizer parameter group (a dict) per named parameter.

        Args:
            lr: learning rate attached to parameters whose name contains
                the substring ``'xxx'``.

        Returns:
            A list of dicts in ``named_parameters()`` order. Each dict has a
            ``'params'`` entry; parameters whose name contains ``'xxx'``
            additionally carry ``'lr': lr``.
        """
        # NOTE(review): 'xxx' looks like a placeholder name filter; unless a
        # parameter name actually contains it, `lr` has no effect.
        return [
            {'params': param, 'lr': lr} if 'xxx' in name else {'params': param}
            for name, param in self.named_parameters()
        ]

    def forward(self, input_grays):
        """Run the wrapped network on ``input_grays`` and return its output."""
        return self.net(input_grays)
|
30 |
+
|
31 |
+
|
32 |
+
class AnchorColorProb(nn.Module):
    """Anchor-guided colour prediction over superpixel tokens.

    The module combines:
      * ``segnet``  - a ``SpixelSeg`` producing the affinity map used for pooling,
      * ``repnet``  - a ``ColorProbNet`` producing 64-channel per-pixel features,
      * ``wildpath`` / ``hintpath`` - two ``TransformerEncoder`` stacks; the
        first predicts a colour distribution per token, the second refines it
        given colour hints at the selected anchor tokens,
      * ``anchorGen`` - an ``anchor_gen.AnchorAnalysis`` choosing the anchors,
      * optionally ``enhanceNet`` - an ``HourGlass2`` for pixel-wise output.
    """

    def __init__(self, inChannel=1, outChannel=313, sp_size=16, d_model=64, use_dense_pos=True,
                 spix_pos=False, learning_pos=False, random_hint=False, hint2regress=False,
                 enhanced=False, use_mask=False, rank=0, colorLabeler=None):
        """Build all sub-networks and initialise their weights.

        Args:
            inChannel: input channels of ``repnet``.
            outChannel: accepted for interface compatibility; not used here
                (the vocabulary size is fixed to 313 via ``self.n_vocab``).
            sp_size: superpixel size passed to the pooling helpers in ``forward``.
            d_model: transformer width. The feature slicing in ``forward``
                assumes 64 feature channels, so other values are untested here.
            use_dense_pos: forwarded to ``EncoderLayer`` / ``TransformerEncoder``.
            spix_pos: if True, positional features are computed per pixel and
                pooled together with the other features.
            learning_pos: accepted but unused (``is_learned=False`` is hard-coded).
            random_hint: selects anchor mode ``'random'`` instead of ``'clustering'``.
            hint2regress: if True the hint branch consumes/produces 2-channel
                colours instead of ``n_vocab``-way class scores.
            enhanced: if True, build ``enhanceNet`` and emit ``pred_colors``.
            use_mask: if True, a padding mask for small superpixels is passed
                to both transformer stacks.
            rank: accepted but unused in this class.
            colorLabeler: object providing ``encode_ab2ind``; also handed to
                ``AnchorAnalysis``.
        """
        super().__init__()
        self.sp_size = sp_size
        self.spix_pos = spix_pos
        self.use_token_mask = use_mask
        self.hint2regress = hint2regress
        self.segnet = SpixelSeg(inChannel=1, outChannel=9, batchNorm=True)
        self.repnet = ColorProbNet(inChannel=inChannel, outChannel=64)
        self.enhanced = enhanced
        if self.enhanced:
            # 64 upsampled feature channels + 1 gray channel (see forward()).
            self.enhanceNet = HourGlass2(inChannel=64+1, outChannel=2, resNum=3, normLayer=nn.BatchNorm2d)

        ## transformer architecture
        self.n_vocab = 313
        dim_feedforward, nhead = 4*d_model, 8
        dropout, activation = 0.1, "relu"
        n_enc_layers = 6
        enc_layer = EncoderLayer(d_model, nhead, dim_feedforward, dropout, activation, use_dense_pos)
        # NOTE(review): both stacks receive the same `enc_layer` object;
        # presumably TransformerEncoder clones it per layer - confirm.
        self.wildpath = TransformerEncoder(enc_layer, n_enc_layers, use_dense_pos)
        self.hintpath = TransformerEncoder(enc_layer, n_enc_layers, use_dense_pos)
        if self.spix_pos:
            n_pos_x, n_pos_y = 256, 256
        else:
            # NOTE(review): the asymmetric 256 vs 16 is kept as in the original - confirm intent.
            n_pos_x, n_pos_y = 256//sp_size, 16//sp_size
        self.pos_enc = build_position_encoding(d_model//2, n_pos_x, n_pos_y, is_learned=False)

        self.mid_word_prj = nn.Linear(d_model, self.n_vocab, bias=False)
        if self.hint2regress:
            # token feature + 2 colour channels + 1 mask channel
            self.trg_word_emb = nn.Linear(d_model+2+1, d_model, bias=False)
            self.trg_word_prj = nn.Linear(d_model, 2, bias=False)
        else:
            # token feature + one-hot colour class + 1 mask channel
            self.trg_word_emb = nn.Linear(d_model+self.n_vocab+1, d_model, bias=False)
            self.trg_word_prj = nn.Linear(d_model, self.n_vocab, bias=False)

        self.colorLabeler = colorLabeler
        anchor_mode = 'random' if random_hint else 'clustering'
        self.anchorGen = anchor_gen.AnchorAnalysis(mode=anchor_mode, colorLabeler=self.colorLabeler)
        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-uniform initialise every parameter with more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def load_and_froze_weight(self, checkpt_path):
        """Load ``segnet`` weights from a checkpoint and freeze that sub-network.

        Args:
            checkpt_path: path to a checkpoint whose ``'state_dict'`` entry
                matches ``self.segnet``.
        """
        # SECURITY NOTE: torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        data_dict = torch.load(checkpt_path, map_location=torch.device('cpu'))
        self.segnet.load_state_dict(data_dict['state_dict'])
        for param in self.segnet.parameters():
            param.requires_grad = False
        self.segnet.eval()

    def set_train(self):
        """Switch the trainable sub-networks to training mode.

        ``segnet`` is deliberately left untouched here.
        """
        ## running mode only affect certain modules, e.g. Dropout, BN, etc.
        self.repnet.train()
        self.wildpath.train()
        self.hintpath.train()
        if self.enhanced:
            self.enhanceNet.train()

    def get_entry_mask(self, mask_tensor):
        """Flatten a mask from (N,1,H,W) to (N,HW); ``None`` passes through."""
        if mask_tensor is None:
            return None
        return mask_tensor.flatten(1)

    def forward(self, input_grays, input_colors, n_anchors=8, sampled_T=0):
        """Predict token colours, then refine them from sampled anchor hints.

        Notice: this function is customized for inference only.

        Args:
            input_grays: gray input fed to ``segnet`` and ``repnet``.
            input_colors: colour channels pooled per superpixel (2 channels
                are assumed by the slicing when ``spix_pos`` is set).
            n_anchors: number of anchors requested from ``anchorGen``.
            sampled_T: ``< 0`` uses the pooled input colours as anchor colours;
                ``> 0`` draws three samples (T=0,1,2) and stacks them along
                the batch axis, which requires an input batch size of 1
                because the other tensors are ``expand``-ed;
                otherwise a single sample with ``T=sampled_T`` is drawn.

        Returns:
            Tuple ``(pal_logit, ref_logit, pred_colors, affinity_map,
            spix_colors, hint_mask)``. ``pred_colors`` is ``None`` unless the
            module was built with ``enhanced=True``.
        """
        affinity_map = self.segnet(input_grays)
        pred_feats = self.repnet(input_grays)
        if self.spix_pos:
            # Pool positional features together with features and colours.
            full_pos_feats = self.pos_enc(pred_feats)
            proxy_feats = torch.cat([pred_feats, input_colors, full_pos_feats], dim=1)
            pooled_proxy_feats, _ = basic.poolfeat(proxy_feats, affinity_map, self.sp_size, self.sp_size, True)
            feat_tokens = pooled_proxy_feats[:, :64, :, :]
            spix_colors = pooled_proxy_feats[:, 64:66, :, :]
            pos_feats = pooled_proxy_feats[:, 66:, :, :]
        else:
            # Positional features are computed directly on the pooled tokens.
            proxy_feats = torch.cat([pred_feats, input_colors], dim=1)
            pooled_proxy_feats, _ = basic.poolfeat(proxy_feats, affinity_map, self.sp_size, self.sp_size, True)
            feat_tokens = pooled_proxy_feats[:, :64, :, :]
            spix_colors = pooled_proxy_feats[:, 64:, :, :]
            pos_feats = self.pos_enc(feat_tokens)

        spixel_sizes = basic.get_spixel_size(affinity_map, self.sp_size, self.sp_size)
        src_pad_mask = None
        if self.use_token_mask:
            # Entries are 1 where the size value falls below 25/sp_size**2, else 0.
            all_one_map = torch.ones(spixel_sizes.shape, device=input_grays.device)
            empty_entries = torch.where(spixel_sizes < 25/(self.sp_size**2), all_one_map, 1-all_one_map)
            src_pad_mask = self.get_entry_mask(empty_entries)

        ## parallel prob
        N, C, H, W = feat_tokens.shape
        ## (N,C,H,W) -> (HW,N,C)
        src_pos_seq = pos_feats.flatten(2).permute(2, 0, 1)
        src_seq = feat_tokens.flatten(2).permute(2, 0, 1)
        ## color prob branch
        enc_out, _ = self.wildpath(src_seq, src_pos_seq, src_pad_mask)
        pal_logit = self.mid_word_prj(enc_out)
        pal_logit = pal_logit.permute(1, 2, 0).view(N, self.n_vocab, H, W)

        ## seed prob branch
        ## mask(N,1,H,W): sample anchors at clustering layers
        color_feat = enc_out.permute(1, 2, 0).view(N, C, H, W)
        hint_mask, _ = self.anchorGen(color_feat, n_anchors, spixel_sizes, use_sklearn_kmeans=False)
        pred_prob = torch.softmax(pal_logit, dim=1)
        if sampled_T < 0:
            ## GT anchor colors
            sampled_spix_colors = spix_colors
        elif sampled_T > 0:
            sampled_spix_colors = torch.cat(
                [self.anchorGen._sample_anchor_colors(pred_prob, hint_mask, T=t) for t in (0, 1, 2)],
                dim=0,
            )
            ## duplicate meta tensors (expand only works from a batch of 1)
            N = 3*N
            input_grays = input_grays.expand(N, -1, -1, -1)
            hint_mask = hint_mask.expand(N, -1, -1, -1)
            affinity_map = affinity_map.expand(N, -1, -1, -1)
            src_seq = src_seq.expand(-1, N, -1)
            src_pos_seq = src_pos_seq.expand(-1, N, -1)
            if src_pad_mask is not None:
                # Keep the padding mask's batch size in step with the sequences.
                src_pad_mask = src_pad_mask.expand(N, -1)
        else:
            sampled_spix_colors = self.anchorGen._sample_anchor_colors(pred_prob, hint_mask, T=sampled_T)

        ## hint based prediction
        ## (N,C,H,W) -> (HW,N,C)
        mask_seq = hint_mask.flatten(2).permute(2, 0, 1)
        if self.hint2regress:
            # Hints are the raw 2-channel colours at anchor positions.
            hint_values = sampled_spix_colors.flatten(2).permute(2, 0, 1)
        else:
            # Hints are one-hot colour classes at anchor positions.
            sampled_token_labels = torch.max(
                self.colorLabeler.encode_ab2ind(sampled_spix_colors), dim=1, keepdim=True
            )[1]
            label_map = F.one_hot(sampled_token_labels, num_classes=self.n_vocab).squeeze(1).float()
            hint_values = label_map.permute(0, 3, 1, 2).flatten(2).permute(2, 0, 1)
        hint_seq = self.trg_word_emb(torch.cat([src_seq, mask_seq * hint_values, mask_seq], dim=2))
        dec_out, _ = self.hintpath(hint_seq, src_pos_seq, src_pad_mask)
        ref_logit = self.trg_word_prj(dec_out)
        Ct = 2 if self.hint2regress else self.n_vocab
        ref_logit = ref_logit.permute(1, 2, 0).view(N, Ct, H, W)

        ## pixelwise enhancement
        pred_colors = None
        if self.enhanced:
            proc_feats = dec_out.permute(1, 2, 0).view(N, 64, H, W)
            full_feats = basic.upfeat(proc_feats, affinity_map, self.sp_size, self.sp_size)
            pred_colors = self.enhanceNet(torch.cat((input_grays, full_feats), dim=1))
            pred_colors = torch.tanh(pred_colors)

        return pal_logit, ref_logit, pred_colors, affinity_map, spix_colors, hint_mask
|