LimiTrix committed
Commit f63cb50 · 1 Parent(s): e4d33ac
Files changed (1): mar.py (+316 -5)
mar.py CHANGED
@@ -19,6 +19,316 @@ def mask_by_order(mask_len, order, bsz, seq_len):
    return masking


+class MARBert(nn.Module):
+    """ Masked Autoencoder with VisionTransformer backbone
+    """
+    def __init__(self, img_size=256, vae_stride=16, patch_size=1,
+                 encoder_embed_dim=1024, encoder_depth=16, encoder_num_heads=16,
+                 decoder_embed_dim=1024, decoder_depth=16, decoder_num_heads=16,
+                 mlp_ratio=4., norm_layer=nn.LayerNorm,
+                 vae_embed_dim=16,
+                 mask_ratio_min=0.7,
+                 label_drop_prob=0.1,
+                 class_num=1000,
+                 attn_dropout=0.1,
+                 proj_dropout=0.1,
+                 buffer_size=64,
+                 diffloss_d=3,
+                 diffloss_w=1024,
+                 num_sampling_steps='100',
+                 diffusion_batch_mul=4,
+                 grad_checkpointing=False,
+                 ):
+        super().__init__()
+
+        # --------------------------------------------------------------------------
+        # VAE and patchify specifics
+        self.vae_embed_dim = vae_embed_dim
+
+        self.img_size = img_size
+        self.vae_stride = vae_stride
+        self.patch_size = patch_size
+        self.seq_h = self.seq_w = img_size // vae_stride // patch_size
+        self.seq_len = self.seq_h * self.seq_w
+        self.token_embed_dim = vae_embed_dim * patch_size**2
+        self.grad_checkpointing = grad_checkpointing
+
+        # --------------------------------------------------------------------------
+        # Class Embedding
+        self.num_classes = class_num
+        self.class_emb = nn.Embedding(1000, encoder_embed_dim)
+        self.label_drop_prob = label_drop_prob
+        # Fake class embedding for CFG's unconditional generation
+        self.fake_latent = nn.Parameter(torch.zeros(1, encoder_embed_dim))
+
+        # --------------------------------------------------------------------------
+        # MAR variant masking ratio, a left-half truncated Gaussian centered at 100% masking ratio with std 0.25
+        self.mask_ratio_generator = stats.truncnorm((mask_ratio_min - 1.0) / 0.25, 0, loc=1.0, scale=0.25)
+
+        # --------------------------------------------------------------------------
+        # MAR encoder specifics
+        self.z_proj = nn.Linear(self.token_embed_dim, encoder_embed_dim, bias=True)
+        self.z_proj_ln = nn.LayerNorm(encoder_embed_dim, eps=1e-6)
+        self.buffer_size = buffer_size
+        self.encoder_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len + self.buffer_size, encoder_embed_dim))
+
+        self.encoder_blocks = nn.ModuleList([
+            Block(encoder_embed_dim, encoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
+                  proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(encoder_depth)])
+        self.encoder_norm = norm_layer(encoder_embed_dim)
+
+        # --------------------------------------------------------------------------
+        # MAR decoder specifics
+        self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
+        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
+        self.decoder_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len + self.buffer_size, decoder_embed_dim))
+
+        self.decoder_blocks = nn.ModuleList([
+            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
+                  proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(decoder_depth)])
+
+        self.decoder_norm = norm_layer(decoder_embed_dim)
+        self.diffusion_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len, decoder_embed_dim))
+
+        self.initialize_weights()
+
+        # --------------------------------------------------------------------------
+        # Diffusion Loss
+        self.diffloss = DiffLoss(
+            target_channels=self.token_embed_dim,
+            z_channels=decoder_embed_dim,
+            width=diffloss_w,
+            depth=diffloss_d,
+            num_sampling_steps=num_sampling_steps,
+            grad_checkpointing=grad_checkpointing
+        )
+        self.diffusion_batch_mul = diffusion_batch_mul
+
+    def initialize_weights(self):
+        # parameters
+        torch.nn.init.normal_(self.class_emb.weight, std=.02)
+        torch.nn.init.normal_(self.fake_latent, std=.02)
+        torch.nn.init.normal_(self.mask_token, std=.02)
+        torch.nn.init.normal_(self.encoder_pos_embed_learned, std=.02)
+        torch.nn.init.normal_(self.decoder_pos_embed_learned, std=.02)
+        torch.nn.init.normal_(self.diffusion_pos_embed_learned, std=.02)
+
+        # initialize nn.Linear and nn.LayerNorm
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            # we use xavier_uniform following official JAX ViT:
+            torch.nn.init.xavier_uniform_(m.weight)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+            if m.weight is not None:
+                nn.init.constant_(m.weight, 1.0)
+
+    def patchify(self, x):
+        bsz, c, h, w = x.shape
+        p = self.patch_size
+        h_, w_ = h // p, w // p
+
+        x = x.reshape(bsz, c, h_, p, w_, p)
+        x = torch.einsum('nchpwq->nhwcpq', x)
+        x = x.reshape(bsz, h_ * w_, c * p ** 2)
+        return x  # [n, l, d]
+
+    def unpatchify(self, x):
+        bsz = x.shape[0]
+        p = self.patch_size
+        c = self.vae_embed_dim
+        h_, w_ = self.seq_h, self.seq_w
+
+        x = x.reshape(bsz, h_, w_, c, p, p)
+        x = torch.einsum('nhwcpq->nchpwq', x)
+        x = x.reshape(bsz, c, h_ * p, w_ * p)
+        return x  # [n, c, h, w]
+
+    def sample_orders(self, bsz):
+        # generate a batch of random generation orders
+        orders = []
+        for _ in range(bsz):
+            order = np.array(list(range(self.seq_len)))
+            np.random.shuffle(order)
+            orders.append(order)
+        orders = torch.Tensor(np.array(orders)).cuda().long()
+        return orders
+
+    def random_masking(self, x, orders):
+        # generate token mask
+        bsz, seq_len, embed_dim = x.shape
+        mask_rate = self.mask_ratio_generator.rvs(1)[0]
+        num_masked_tokens = int(np.ceil(seq_len * mask_rate))
+        mask = torch.zeros(bsz, seq_len, device=x.device)
+        mask = torch.scatter(mask, dim=-1, index=orders[:, :num_masked_tokens],
+                             src=torch.ones(bsz, seq_len, device=x.device))
+        return mask
+
+    def forward_mae_encoder(self, x, mask, class_embedding):
+        x = self.z_proj(x)
+        bsz, seq_len, embed_dim = x.shape
+
+        # concat buffer
+        x = torch.cat([torch.zeros(bsz, self.buffer_size, embed_dim, device=x.device), x], dim=1)
+        mask_with_buffer = torch.cat([torch.zeros(x.size(0), self.buffer_size, device=x.device), mask], dim=1)
+
+        # random drop class embedding during training
+        if self.training:
+            drop_latent_mask = torch.rand(bsz) < self.label_drop_prob
+            drop_latent_mask = drop_latent_mask.unsqueeze(-1).cuda().to(x.dtype)
+            class_embedding = drop_latent_mask * self.fake_latent + (1 - drop_latent_mask) * class_embedding
+
+        x[:, :self.buffer_size] = class_embedding.unsqueeze(1)
+
+        # encoder position embedding
+        x = x + self.encoder_pos_embed_learned
+        x = self.z_proj_ln(x)
+
+        # dropping
+        x = x[(1-mask_with_buffer).nonzero(as_tuple=True)].reshape(bsz, -1, embed_dim)
+
+        # apply Transformer blocks
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            for block in self.encoder_blocks:
+                x = checkpoint(block, x)
+        else:
+            for block in self.encoder_blocks:
+                x = block(x)
+        x = self.encoder_norm(x)
+
+        return x
+
+    def forward_mae_decoder(self, x, mask):
+
+        x = self.decoder_embed(x)
+        mask_with_buffer = torch.cat([torch.zeros(x.size(0), self.buffer_size, device=x.device), mask], dim=1)
+
+        # pad mask tokens
+        mask_tokens = self.mask_token.repeat(mask_with_buffer.shape[0], mask_with_buffer.shape[1], 1).to(x.dtype)
+        x_after_pad = mask_tokens.clone()
+        x_after_pad[(1 - mask_with_buffer).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
+
+        # decoder position embedding
+        x = x_after_pad + self.decoder_pos_embed_learned
+
+        # apply Transformer blocks
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            for block in self.decoder_blocks:
+                x = checkpoint(block, x)
+        else:
+            for block in self.decoder_blocks:
+                x = block(x)
+        x = self.decoder_norm(x)
+
+        x = x[:, self.buffer_size:]
+        x = x + self.diffusion_pos_embed_learned
+        return x
+
+    def forward_loss(self, z, target, mask):
+        bsz, seq_len, _ = target.shape
+        target = target.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
+        z = z.reshape(bsz*seq_len, -1).repeat(self.diffusion_batch_mul, 1)
+        mask = mask.reshape(bsz*seq_len).repeat(self.diffusion_batch_mul)
+        loss = self.diffloss(z=z, target=target, mask=mask)
+        return loss
+
+    def forward(self, imgs, labels):
+
+        # class embed
+        class_embedding = self.class_emb(labels)
+
+        # patchify and mask (drop) tokens
+        x = self.patchify(imgs)
+        gt_latents = x.clone().detach()
+        orders = self.sample_orders(bsz=x.size(0))
+        mask = self.random_masking(x, orders)
+
+        # mae encoder
+        x = self.forward_mae_encoder(x, mask, class_embedding)
+
+        # mae decoder
+        z = self.forward_mae_decoder(x, mask)
+
+        # diffloss
+        loss = self.forward_loss(z=z, target=gt_latents, mask=mask)
+
+        return loss
+
+    def sample_tokens(self, bsz, num_iter=64, cfg=1.0, cfg_schedule="linear", labels=None, temperature=1.0, progress=False):
+
+        # init and sample generation orders
+        mask = torch.ones(bsz, self.seq_len).cuda()
+        tokens = torch.zeros(bsz, self.seq_len, self.token_embed_dim).cuda()
+        orders = self.sample_orders(bsz)
+
+        indices = list(range(num_iter))
+        if progress:
+            indices = tqdm(indices)
+        # generate latents
+        for step in indices:
+            cur_tokens = tokens.clone()
+            print(cur_tokens.shape)
+
+            # class embedding and CFG
+            if labels is not None:
+                class_embedding = self.class_emb(labels)
+            else:
+                class_embedding = self.fake_latent.repeat(bsz, 1)
+            if not cfg == 1.0:
+                tokens = torch.cat([tokens, tokens], dim=0)
+                class_embedding = torch.cat([class_embedding, self.fake_latent.repeat(bsz, 1)], dim=0)
+                mask = torch.cat([mask, mask], dim=0)
+
+            # mae encoder
+            x = self.forward_mae_encoder(tokens, mask, class_embedding)
+
+            # mae decoder
+            z = self.forward_mae_decoder(x, mask)
+
+            # mask ratio for the next round, following MaskGIT and MAGE.
+            mask_ratio = np.cos(math.pi / 2. * (step + 1) / num_iter)
+            mask_len = torch.Tensor([np.floor(self.seq_len * mask_ratio)]).cuda()
+
+            # masks out at least one for the next iteration
+            mask_len = torch.maximum(torch.Tensor([1]).cuda(),
+                                     torch.minimum(torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len))
+
+            # get masking for next iteration and locations to be predicted in this iteration
+            mask_next = mask_by_order(mask_len[0], orders, bsz, self.seq_len)
+            if step >= num_iter - 1:
+                mask_to_pred = mask[:bsz].bool()
+            else:
+                mask_to_pred = torch.logical_xor(mask[:bsz].bool(), mask_next.bool())
+            mask = mask_next
+            if not cfg == 1.0:
+                mask_to_pred = torch.cat([mask_to_pred, mask_to_pred], dim=0)
+
+            # sample token latents for this step
+            z = z[mask_to_pred.nonzero(as_tuple=True)]
+            # cfg schedule follow Muse
+            if cfg_schedule == "linear":
+                cfg_iter = 1 + (cfg - 1) * (self.seq_len - mask_len[0]) / self.seq_len
+            elif cfg_schedule == "constant":
+                cfg_iter = cfg
+            else:
+                raise NotImplementedError
+            sampled_token_latent = self.diffloss.sample(z, temperature, cfg_iter)
+            if not cfg == 1.0:
+                sampled_token_latent, _ = sampled_token_latent.chunk(2, dim=0)  # Remove null class samples
+                mask_to_pred, _ = mask_to_pred.chunk(2, dim=0)
+
+            cur_tokens[mask_to_pred.nonzero(as_tuple=True)] = sampled_token_latent
+            tokens = cur_tokens.clone()
+
+        # unpatchify
+        tokens = self.unpatchify(tokens)
+        return tokens
+
class MAR(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone
    """
 
@@ -275,11 +585,10 @@ class MAR(nn.Module):
            print(cur_tokens.shape)

            # class embedding and CFG
-            # if labels is not None:
-            #     class_embedding = self.class_emb(labels)
-            # else:
-            #
-            class_embedding = self.fake_latent.repeat(bsz, 1)
+            if labels is not None:
+                class_embedding = self.class_emb(labels)
+            else:
+                class_embedding = self.fake_latent.repeat(bsz, 1)
            if not cfg == 1.0:
                tokens = torch.cat([tokens, tokens], dim=0)
                class_embedding = torch.cat([class_embedding, self.fake_latent.repeat(bsz, 1)], dim=0)
 
@@ -327,7 +636,9 @@ class MAR(nn.Module):
            tokens = cur_tokens.clone()

        # unpatchify
+        print(tokens.shape)
        tokens = self.unpatchify(tokens)
+        print(tokens.shape)
        return tokens

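
For orientation, below is a minimal usage sketch of the newly added MARBert class. It is not part of the commit: it assumes mar.py's module-level imports and helpers (torch, nn, Block, DiffLoss, mask_by_order) resolve, and it assumes a CUDA device, since sample_orders() and the class-embedding drop path call .cuda() directly. The tensor shapes follow the constructor defaults (img_size=256, vae_stride=16, patch_size=1, vae_embed_dim=16, i.e. a 16x16 grid of 16-dimensional latent tokens).

# Hypothetical sketch, not from the commit: exercising MARBert with its defaults.
import torch
from mar import MARBert

model = MARBert().cuda()                      # 16x16 latent grid, 16-dim tokens

# Training-style step: `imgs` are VAE latents, not raw pixels:
# shape [bsz, vae_embed_dim, img_size // vae_stride, img_size // vae_stride]
latents = torch.randn(2, 16, 16, 16).cuda()
labels = torch.randint(0, 1000, (2,)).cuda()
loss = model(latents, labels)                 # diffusion loss on masked positions
loss.backward()

# Generation: iterative masked decoding; returns latents for the VAE decoder.
model.eval()
with torch.no_grad():
    sampled = model.sample_tokens(bsz=2, num_iter=64, cfg=1.0, labels=labels)
print(sampled.shape)                          # expected: torch.Size([2, 16, 16, 16])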