Spaces:
Runtime error
Runtime error
Mehdi Cherti
commited on
Commit
•
bc53ac3
1
Parent(s):
c81908d
support cond attn based discriminator
Browse files- pytorch_fid/fid_score.py +1 -1
- score_sde/models/discriminator.py +159 -0
pytorch_fid/fid_score.py
CHANGED
@@ -148,7 +148,7 @@ def get_activations(files, model, batch_size=50, dims=2048, device='cpu', resize
|
|
148 |
|
149 |
for batch in tqdm(dataloader):
|
150 |
batch = batch.to(device)
|
151 |
-
print(batch.shape, batch.min(), batch.max)
|
152 |
with torch.no_grad():
|
153 |
pred = model(batch)[0]
|
154 |
|
|
|
148 |
|
149 |
for batch in tqdm(dataloader):
|
150 |
batch = batch.to(device)
|
151 |
+
#print(batch.shape, batch.min(), batch.max)
|
152 |
with torch.no_grad():
|
153 |
pred = model(batch)[0]
|
154 |
|
score_sde/models/discriminator.py
CHANGED
@@ -167,6 +167,87 @@ class Discriminator_small(nn.Module):
|
|
167 |
|
168 |
return out
|
169 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
170 |
|
171 |
class Discriminator_large(nn.Module):
|
172 |
"""A time-dependent discriminator for large images (CelebA, LSUN)."""
|
@@ -239,3 +320,81 @@ class Discriminator_large(nn.Module):
|
|
239 |
out = self.end_linear(out) + (self.cond_proj(cond) * out).sum(dim=1, keepdim=True)
|
240 |
return out
|
241 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
|
168 |
return out
|
169 |
|
170 |
+
class SmallCondAttnDiscriminator(nn.Module):
|
171 |
+
"""A time-dependent discriminator for small images (CIFAR10, StackMNIST)."""
|
172 |
+
|
173 |
+
def __init__(self, nc = 3, ngf = 64, t_emb_dim = 128, act=nn.LeakyReLU(0.2), cond_size=768):
|
174 |
+
super().__init__()
|
175 |
+
# Gaussian random feature embedding layer for time
|
176 |
+
self.act = act
|
177 |
+
self.cond_attn = layers.CondAttnBlock(ngf*8, cond_size, dim_head=64, heads=8, norm_context=False, cosine_sim_attn=False)
|
178 |
+
|
179 |
+
self.t_embed = TimestepEmbedding(
|
180 |
+
embedding_dim=t_emb_dim,
|
181 |
+
hidden_dim=t_emb_dim,
|
182 |
+
output_dim=t_emb_dim,
|
183 |
+
act=act,
|
184 |
+
)
|
185 |
+
|
186 |
+
|
187 |
+
|
188 |
+
# Encoding layers where the resolution decreases
|
189 |
+
self.start_conv = conv2d(nc,ngf*2,1, padding=0)
|
190 |
+
self.conv1 = DownConvBlock(ngf*2, ngf*2, t_emb_dim = t_emb_dim,act=act)
|
191 |
+
|
192 |
+
self.conv2 = DownConvBlock(ngf*2, ngf*4, t_emb_dim = t_emb_dim, downsample=True,act=act)
|
193 |
+
|
194 |
+
|
195 |
+
self.conv3 = DownConvBlock(ngf*4, ngf*8, t_emb_dim = t_emb_dim, downsample=True,act=act)
|
196 |
+
|
197 |
+
|
198 |
+
self.conv4 = DownConvBlock(ngf*8, ngf*8, t_emb_dim = t_emb_dim, downsample=True,act=act)
|
199 |
+
|
200 |
+
|
201 |
+
self.final_conv = conv2d(ngf*8 + 1, ngf*8, 3,padding=1, init_scale=0.)
|
202 |
+
self.end_linear = dense(ngf*8, 1)
|
203 |
+
self.end_linear_cond = dense(ngf*8, 1)
|
204 |
+
#self.gn_cond = nn.GroupNorm(num_groups=32, num_channels=ngf*8, eps=1e-6)
|
205 |
+
|
206 |
+
self.stddev_group = 4
|
207 |
+
self.stddev_feat = 1
|
208 |
+
|
209 |
+
|
210 |
+
def forward(self, x, t, x_t, cond=None):
|
211 |
+
t_embed = self.t_embed(t)
|
212 |
+
# if cond is not None:
|
213 |
+
# t_embed = t_embed + self.cond_proj(cond)
|
214 |
+
t_embed = self.act(t_embed)
|
215 |
+
input_x = torch.cat((x, x_t), dim = 1)
|
216 |
+
|
217 |
+
h0 = self.start_conv(input_x)
|
218 |
+
h1 = self.conv1(h0,t_embed)
|
219 |
+
|
220 |
+
h2 = self.conv2(h1,t_embed)
|
221 |
+
|
222 |
+
h3 = self.conv3(h2,t_embed)
|
223 |
+
|
224 |
+
|
225 |
+
out = self.conv4(h3,t_embed)
|
226 |
+
|
227 |
+
batch, channel, height, width = out.shape
|
228 |
+
group = min(batch, self.stddev_group)
|
229 |
+
stddev = out.view(
|
230 |
+
group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
|
231 |
+
)
|
232 |
+
stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
|
233 |
+
stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
|
234 |
+
stddev = stddev.repeat(group, 1, height, width)
|
235 |
+
out = torch.cat([out, stddev], 1)
|
236 |
+
|
237 |
+
out = self.final_conv(out)
|
238 |
+
out = self.act(out)
|
239 |
+
|
240 |
+
cond_pooled, cond, cond_mask = cond
|
241 |
+
|
242 |
+
out_cond = (self.cond_attn(out, cond, cond_mask))
|
243 |
+
|
244 |
+
out = out.view(out.shape[0], out.shape[1], -1).mean(2)
|
245 |
+
out_cond = out_cond.view(out_cond.shape[0], out_cond.shape[1], -1).mean(2)
|
246 |
+
out = self.end_linear(out) + self.end_linear_cond(out_cond)
|
247 |
+
return out
|
248 |
+
|
249 |
+
|
250 |
+
|
251 |
|
252 |
class Discriminator_large(nn.Module):
|
253 |
"""A time-dependent discriminator for large images (CelebA, LSUN)."""
|
|
|
320 |
out = self.end_linear(out) + (self.cond_proj(cond) * out).sum(dim=1, keepdim=True)
|
321 |
return out
|
322 |
|
323 |
+
|
324 |
+
class CondAttnDiscriminator(nn.Module):
|
325 |
+
"""A time-dependent discriminator for large images (CelebA, LSUN)."""
|
326 |
+
|
327 |
+
def __init__(self, nc = 1, ngf = 32, t_emb_dim = 128, act=nn.LeakyReLU(0.2), cond_size=768):
|
328 |
+
super().__init__()
|
329 |
+
# Gaussian random feature embedding layer for time
|
330 |
+
self.act = act
|
331 |
+
self.cond_attn = layers.CondAttnBlock(ngf*8, cond_size, dim_head=64, heads=8, norm_context=False, cosine_sim_attn=False)
|
332 |
+
|
333 |
+
self.t_embed = TimestepEmbedding(
|
334 |
+
embedding_dim=t_emb_dim,
|
335 |
+
hidden_dim=t_emb_dim,
|
336 |
+
output_dim=t_emb_dim,
|
337 |
+
act=act,
|
338 |
+
)
|
339 |
+
|
340 |
+
self.start_conv = conv2d(nc,ngf*2,1, padding=0)
|
341 |
+
self.conv1 = DownConvBlock(ngf*2, ngf*4, t_emb_dim = t_emb_dim, downsample = True, act=act)
|
342 |
+
|
343 |
+
self.conv2 = DownConvBlock(ngf*4, ngf*8, t_emb_dim = t_emb_dim, downsample=True,act=act)
|
344 |
+
|
345 |
+
self.conv3 = DownConvBlock(ngf*8, ngf*8, t_emb_dim = t_emb_dim, downsample=True,act=act)
|
346 |
+
|
347 |
+
|
348 |
+
self.conv4 = DownConvBlock(ngf*8, ngf*8, t_emb_dim = t_emb_dim, downsample=True,act=act)
|
349 |
+
self.conv5 = DownConvBlock(ngf*8, ngf*8, t_emb_dim = t_emb_dim, downsample=True,act=act)
|
350 |
+
self.conv6 = DownConvBlock(ngf*8, ngf*8, t_emb_dim = t_emb_dim, downsample=True,act=act)
|
351 |
+
|
352 |
+
|
353 |
+
self.final_conv = conv2d(ngf*8 + 1, ngf*8, 3,padding=1)
|
354 |
+
self.end_linear = dense(ngf*8, 1)
|
355 |
+
self.end_linear_cond = dense(ngf*8, 1)
|
356 |
+
|
357 |
+
self.stddev_group = 4
|
358 |
+
self.stddev_feat = 1
|
359 |
+
|
360 |
+
|
361 |
+
def forward(self, x, t, x_t, cond=None):
|
362 |
+
cond_pooled, cond, cond_mask = cond
|
363 |
+
|
364 |
+
t_embed = self.t_embed(t)
|
365 |
+
t_embed = self.act(t_embed)
|
366 |
+
|
367 |
+
input_x = torch.cat((x, x_t), dim = 1)
|
368 |
+
|
369 |
+
h = self.start_conv(input_x)
|
370 |
+
h = self.conv1(h,t_embed)
|
371 |
+
|
372 |
+
h = self.conv2(h,t_embed)
|
373 |
+
|
374 |
+
h = self.conv3(h,t_embed)
|
375 |
+
h = self.conv4(h,t_embed)
|
376 |
+
h = self.conv5(h,t_embed)
|
377 |
+
|
378 |
+
|
379 |
+
out = self.conv6(h,t_embed)
|
380 |
+
|
381 |
+
batch, channel, height, width = out.shape
|
382 |
+
group = min(batch, self.stddev_group)
|
383 |
+
stddev = out.view(
|
384 |
+
group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
|
385 |
+
)
|
386 |
+
stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
|
387 |
+
stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
|
388 |
+
stddev = stddev.repeat(group, 1, height, width)
|
389 |
+
out = torch.cat([out, stddev], 1)
|
390 |
+
|
391 |
+
out = self.final_conv(out)
|
392 |
+
out = self.act(out)
|
393 |
+
|
394 |
+
out_cond = self.cond_attn(out, cond, cond_mask)
|
395 |
+
|
396 |
+
|
397 |
+
out = out.view(out.shape[0], out.shape[1], -1).mean(2)
|
398 |
+
out_cond = out_cond.view(out_cond.shape[0], out_cond.shape[1], -1).mean(2)
|
399 |
+
out = self.end_linear(out) + self.end_linear_cond(out_cond)
|
400 |
+
return out
|