doevent committed on
Commit
f4de781
·
1 Parent(s): 69269d9

Upload models/model.py

Browse files
Files changed (1) hide show
  1. models/model.py +196 -0
models/model.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from models.network import HourGlass2, SpixelNet, ColorProbNet
5
+ from models.transformer2d import EncoderLayer, DecoderLayer, TransformerEncoder, TransformerDecoder
6
+ from models.position_encoding import build_position_encoding
7
+ from models import basic, clusterkit, anchor_gen
8
+ from collections import OrderedDict
9
+ from utils import util, cielab
10
+
11
+
12
class SpixelSeg(nn.Module):
    """Thin ``nn.Module`` wrapper that owns a single ``SpixelNet`` as ``self.net``.

    The constructor arguments are forwarded unchanged to ``SpixelNet``.
    """

    def __init__(self, inChannel=1, outChannel=9, batchNorm=True):
        super().__init__()
        self.net = SpixelNet(
            inChannel=inChannel,
            outChannel=outChannel,
            batchNorm=batchNorm,
        )

    def get_trainable_params(self, lr=1.0):
        """Build one optimizer parameter group per named parameter.

        Args:
            lr: learning rate stored under ``'lr'`` for every parameter whose
                name contains the marker substring ``'xxx'``.

        Returns:
            A list of dicts, in ``named_parameters()`` order. Each dict holds
            the parameter under ``'params'``; only marker-matched parameters
            additionally carry ``'lr'``.
        """
        def _as_group(param_name, tensor):
            group = {'params': tensor}
            if 'xxx' in param_name:
                group['lr'] = lr
            return group

        return [_as_group(param_name, tensor)
                for param_name, tensor in self.named_parameters()]

    def forward(self, input_grays):
        """Run ``self.net`` on ``input_grays`` and return its output unchanged."""
        return self.net(input_grays)
30
+
31
+
32
class AnchorColorProb(nn.Module):
    """Superpixel-token color predictor with an anchor-hinted refinement pass.

    The module chains: a frozen-able superpixel segmenter (``segnet``), a
    feature extractor (``repnet``), a first transformer encoder (``wildpath``)
    that yields per-token color logits, an anchor generator (``anchorGen``)
    that selects hint tokens, a second transformer encoder (``hintpath``) that
    re-predicts colors given the hints, and an optional pixel-level
    enhancement network (``enhanceNet``).
    """

    def __init__(self, inChannel=1, outChannel=313, sp_size=16, d_model=64, use_dense_pos=True, spix_pos=False,
                 learning_pos=False, random_hint=False, hint2regress=False, enhanced=False, use_mask=False,
                 rank=0, colorLabeler=None):
        """Build all sub-networks.

        Args:
            inChannel: input channel count handed to ``ColorProbNet``.
            outChannel: accepted for interface compatibility; not read here
                (the vocabulary size is fixed to 313 via ``self.n_vocab``).
            sp_size: superpixel cell size passed to the pooling helpers.
            d_model: transformer width.
            use_dense_pos: forwarded to ``EncoderLayer`` / ``TransformerEncoder``.
            spix_pos: if True, position features are computed at full
                resolution and pooled together with the image features.
            learning_pos: accepted for interface compatibility; not read here.
            random_hint: selects anchor mode ``'random'`` instead of
                ``'clustering'``.
            hint2regress: if True, the hint branch consumes/produces 2-channel
                color values instead of ``n_vocab`` class logits.
            enhanced: if True, build the pixel-wise ``enhanceNet``.
            use_mask: if True, ``forward`` passes a padding mask for
                near-empty superpixels to both transformers.
            rank: accepted for interface compatibility; not read here.
            colorLabeler: object providing ``encode_ab2ind``; also passed to
                the anchor generator.
        """
        super().__init__()
        self.sp_size = sp_size
        self.spix_pos = spix_pos
        self.use_token_mask = use_mask
        self.hint2regress = hint2regress
        self.segnet = SpixelSeg(inChannel=1, outChannel=9, batchNorm=True)
        self.repnet = ColorProbNet(inChannel=inChannel, outChannel=64)
        self.enhanced = enhanced
        if self.enhanced:
            # 64 feature channels + 1 gray channel in, 2 color channels out.
            self.enhanceNet = HourGlass2(inChannel=64+1, outChannel=2, resNum=3, normLayer=nn.BatchNorm2d)

        ## transformer architecture
        self.n_vocab = 313
        dim_feedforward, nhead = 4*d_model, 8
        dropout, activation = 0.1, "relu"
        n_enc_layers = 6
        # NOTE(review): the same enc_layer instance is handed to both encoders;
        # whether they end up sharing weights depends on TransformerEncoder
        # (not visible in this file) -- confirm it clones the layer.
        enc_layer = EncoderLayer(d_model, nhead, dim_feedforward, dropout, activation, use_dense_pos)
        self.wildpath = TransformerEncoder(enc_layer, n_enc_layers, use_dense_pos)
        self.hintpath = TransformerEncoder(enc_layer, n_enc_layers, use_dense_pos)
        if self.spix_pos:
            n_pos_x, n_pos_y = 256, 256
        else:
            # NOTE(review): 16//sp_size looks asymmetric next to 256//sp_size;
            # kept as-is because its effect inside build_position_encoding
            # cannot be verified from here -- confirm it is intended.
            n_pos_x, n_pos_y = 256//sp_size, 16//sp_size
        self.pos_enc = build_position_encoding(d_model//2, n_pos_x, n_pos_y, is_learned=False)

        self.mid_word_prj = nn.Linear(d_model, self.n_vocab, bias=False)
        if self.hint2regress:
            # token feature + 2 hint color channels + 1 hint-mask channel
            self.trg_word_emb = nn.Linear(d_model+2+1, d_model, bias=False)
            self.trg_word_prj = nn.Linear(d_model, 2, bias=False)
        else:
            # token feature + one-hot hint label + 1 hint-mask channel
            self.trg_word_emb = nn.Linear(d_model+self.n_vocab+1, d_model, bias=False)
            self.trg_word_prj = nn.Linear(d_model, self.n_vocab, bias=False)

        self.colorLabeler = colorLabeler
        anchor_mode = 'random' if random_hint else 'clustering'
        self.anchorGen = anchor_gen.AnchorAnalysis(mode=anchor_mode, colorLabeler=self.colorLabeler)
        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-uniform initialize every parameter with more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def load_and_froze_weight(self, checkpt_path):
        """Load ``segnet`` weights from a checkpoint, freeze them, and put ``segnet`` in eval mode.

        Args:
            checkpt_path: path to a checkpoint whose ``'state_dict'`` entry
                matches ``self.segnet``.
        """
        # SECURITY NOTE: torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        data_dict = torch.load(checkpt_path, map_location=torch.device('cpu'))
        self.segnet.load_state_dict(data_dict['state_dict'])
        for param in self.segnet.parameters():
            param.requires_grad = False
        self.segnet.eval()

    def set_train(self):
        """Switch the trainable sub-networks to training mode.

        Only ``repnet``, ``wildpath``, ``hintpath`` and (if present)
        ``enhanceNet`` are touched; ``segnet`` is deliberately left in
        whatever mode it already has.
        """
        ## running mode only affect certain modules, e.g. Dropout, BN, etc.
        self.repnet.train()
        self.wildpath.train()
        self.hintpath.train()
        if self.enhanced:
            self.enhanceNet.train()

    def get_entry_mask(self, mask_tensor):
        """Flatten a mask from (N,1,H,W) to (N,HW); ``None`` passes through."""
        if mask_tensor is None:
            return None
        return mask_tensor.flatten(1)

    @staticmethod
    def _tile_batch(tensor, times, dim=0):
        """Replicate ``tensor`` ``times`` times along ``dim``, block-wise.

        The result is laid out like ``torch.cat([tensor] * times, dim)``.
        A size-1 dimension is expanded as a view (no copy); larger dimensions
        are copied with ``repeat``.
        """
        if tensor.size(dim) == 1:
            sizes = [-1] * tensor.dim()
            sizes[dim] = times
            return tensor.expand(*sizes)
        reps = [1] * tensor.dim()
        reps[dim] = times
        return tensor.repeat(*reps)

    def forward(self, input_grays, input_colors, n_anchors=8, sampled_T=0):
        """Predict colors from a gray image; customized for inference only.

        Args:
            input_grays: gray input fed to ``segnet`` and ``repnet``.
            input_colors: color channels pooled alongside the features to
                obtain per-superpixel reference colors.
            n_anchors: number of anchors requested from the anchor generator.
            sampled_T: anchor-color source.
                ``< 0``: use the pooled reference colors;
                ``> 0``: draw three samples (T=0, 1, 2) and triple the batch;
                ``0``: draw a single sample with T=0.

        Returns:
            Tuple ``(pal_logit, ref_logit, pred_colors, affinity_map,
            spix_colors, hint_mask)``. ``pal_logit`` is (N, n_vocab, H, W) at
            the input batch size; ``ref_logit`` is (N', Ct, H, W) with
            ``Ct = 2`` when ``hint2regress`` else ``n_vocab`` and ``N'``
            tripled when ``sampled_T > 0``. ``pred_colors`` is ``None``
            unless ``enhanced``. ``affinity_map`` and ``hint_mask`` are
            returned at batch size ``N'``; ``spix_colors`` stays at the input
            batch size.
        """
        affinity_map = self.segnet(input_grays)
        pred_feats = self.repnet(input_grays)
        if self.spix_pos:
            # Pool dense position features together with image features.
            full_pos_feats = self.pos_enc(pred_feats)
            proxy_feats = torch.cat([pred_feats, input_colors, full_pos_feats], dim=1)
            pooled_proxy_feats, _ = basic.poolfeat(proxy_feats, affinity_map, self.sp_size, self.sp_size, True)
            feat_tokens = pooled_proxy_feats[:, :64, :, :]
            spix_colors = pooled_proxy_feats[:, 64:66, :, :]
            pos_feats = pooled_proxy_feats[:, 66:, :, :]
        else:
            proxy_feats = torch.cat([pred_feats, input_colors], dim=1)
            pooled_proxy_feats, _ = basic.poolfeat(proxy_feats, affinity_map, self.sp_size, self.sp_size, True)
            feat_tokens = pooled_proxy_feats[:, :64, :, :]
            spix_colors = pooled_proxy_feats[:, 64:, :, :]
            pos_feats = self.pos_enc(feat_tokens)

        spixel_sizes = basic.get_spixel_size(affinity_map, self.sp_size, self.sp_size)
        src_pad_mask = None
        if self.use_token_mask:
            # Mark superpixels whose size falls below 25/(sp_size^2) as padding.
            all_one_map = torch.ones(spixel_sizes.shape, device=input_grays.device)
            empty_entries = torch.where(spixel_sizes < 25/(self.sp_size**2), all_one_map, 1-all_one_map)
            src_pad_mask = self.get_entry_mask(empty_entries)

        ## parallel prob
        N, C, H, W = feat_tokens.shape
        ## (N,C,H,W) -> (HW,N,C)
        src_pos_seq = pos_feats.flatten(2).permute(2, 0, 1)
        src_seq = feat_tokens.flatten(2).permute(2, 0, 1)
        ## color prob branch
        enc_out, _ = self.wildpath(src_seq, src_pos_seq, src_pad_mask)
        pal_logit = self.mid_word_prj(enc_out)
        pal_logit = pal_logit.permute(1, 2, 0).view(N, self.n_vocab, H, W)

        ## seed prob branch
        ## mask(N,1,H,W): sample anchors at clustering layers
        color_feat = enc_out.permute(1, 2, 0).view(N, C, H, W)
        hint_mask, _ = self.anchorGen(color_feat, n_anchors, spixel_sizes, use_sklearn_kmeans=False)
        pred_prob = torch.softmax(pal_logit, dim=1)
        if sampled_T < 0:
            ## GT anchor colors
            sampled_spix_colors = spix_colors
        elif sampled_T > 0:
            sampled_spix_colors = torch.cat(
                [self.anchorGen._sample_anchor_colors(pred_prob, hint_mask, T=t) for t in range(3)],
                dim=0,
            )
            ## duplicate meta tensors so they line up with the 3 stacked samples
            input_grays = self._tile_batch(input_grays, 3, 0)
            hint_mask = self._tile_batch(hint_mask, 3, 0)
            affinity_map = self._tile_batch(affinity_map, 3, 0)
            src_seq = self._tile_batch(src_seq, 3, 1)
            src_pos_seq = self._tile_batch(src_pos_seq, 3, 1)
            if src_pad_mask is not None:
                # Keep the padding mask's batch size in step with the sequences.
                src_pad_mask = self._tile_batch(src_pad_mask, 3, 0)
            N = 3*N
        else:
            sampled_spix_colors = self.anchorGen._sample_anchor_colors(pred_prob, hint_mask, T=sampled_T)

        ## hint based prediction
        ## (N,C,H,W) -> (HW,N,C)
        mask_seq = hint_mask.flatten(2).permute(2, 0, 1)
        if self.hint2regress:
            hint_values = sampled_spix_colors.flatten(2).permute(2, 0, 1)
        else:
            # Index of the strongest color bin per token, then one-hot encode it.
            sampled_token_labels = torch.max(
                self.colorLabeler.encode_ab2ind(sampled_spix_colors), dim=1, keepdim=True)[1]
            label_map = F.one_hot(sampled_token_labels, num_classes=self.n_vocab).squeeze(1).float()
            hint_values = label_map.permute(0, 3, 1, 2).flatten(2).permute(2, 0, 1)
        # Hint values are zeroed everywhere except at anchor tokens.
        hint_seq = self.trg_word_emb(torch.cat([src_seq, mask_seq * hint_values, mask_seq], dim=2))
        dec_out, _ = self.hintpath(hint_seq, src_pos_seq, src_pad_mask)
        ref_logit = self.trg_word_prj(dec_out)
        Ct = 2 if self.hint2regress else self.n_vocab
        ref_logit = ref_logit.permute(1, 2, 0).view(N, Ct, H, W)

        ## pixelwise enhancement
        pred_colors = None
        if self.enhanced:
            proc_feats = dec_out.permute(1, 2, 0).view(N, 64, H, W)
            full_feats = basic.upfeat(proc_feats, affinity_map, self.sp_size, self.sp_size)
            pred_colors = self.enhanceNet(torch.cat((input_grays, full_feats), dim=1))
            pred_colors = torch.tanh(pred_colors)

        return pal_logit, ref_logit, pred_colors, affinity_map, spix_colors, hint_mask