Mehdi Cherti commited on
Commit
7060b15
1 Parent(s): a45817e

update models

Browse files
Files changed (2) hide show
  1. clip_encoder.py +21 -0
  2. run.py +14 -1
clip_encoder.py CHANGED
@@ -62,3 +62,24 @@ class CLIPImageEncoder(nn.Module):
62
 
63
 
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
 
64
 
65
+ class OpenCLIPImageEncoder(nn.Module):
66
+
67
+ def __init__(self, model="ViT-B/32", pretrained="openai"):
68
+ super().__init__()
69
+ model, _, preprocess = open_clip.create_model_and_transforms(model, pretrained=pretrained)
70
+ self.tokenizer = open_clip.get_tokenizer(model)
71
+ CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
72
+ CLIP_STD = [0.26862954, 0.26130258, 0.27577711]
73
+ mean = torch.tensor(CLIP_MEAN).view(1, 3, 1, 1)
74
+ std = torch.tensor(CLIP_STD).view(1, 3, 1, 1)
75
+ self.register_buffer("mean", mean)
76
+ self.register_buffer("std", std)
77
+
78
+ def forward_image(self, x):
79
+ x = torch.nn.functional.interpolate(x, mode='bicubic', size=(224, 224))
80
+ x = (x-self.mean)/self.std
81
+ return self.model.encode_image(x)
82
+
83
+ def forward_text(self, texts):
84
+ toks = self.tokenizer.tokenize(texts, truncate=True).to(self.mean.device)
85
+ return self.model.encode_text(toks)
run.py CHANGED
@@ -237,7 +237,7 @@ def ddgan_laion2b_v2():
237
  return cfg
238
 
239
  def ddgan_ddb_v1():
240
- cfg = ddgan_sd_v9()
241
  return cfg
242
 
243
  def ddgan_sd_v11():
@@ -245,6 +245,17 @@ def ddgan_sd_v11():
245
  cfg['model']['image_size'] = 512
246
  return cfg
247
 
 
 
 
 
 
 
 
 
 
 
 
248
  models = [
249
  ddgan_cifar10_cond17, # cifar10, cross attn for discr
250
  ddgan_cifar10_cond18, # cifar10, xl encoder
@@ -286,6 +297,8 @@ models = [
286
  ddgan_sd_v11,
287
  ddgan_laion2b_v2,
288
  ddgan_ddb_v1,
 
 
289
  ]
290
 
291
  def get_model(model_name):
 
237
  return cfg
238
 
239
  def ddgan_ddb_v1():
240
+ cfg = ddgan_sd_v10()
241
  return cfg
242
 
243
  def ddgan_sd_v11():
 
245
  cfg['model']['image_size'] = 512
246
  return cfg
247
 
248
+ def ddgan_ddb_v2():
249
+ cfg = ddgan_ddb_v1()
250
+ cfg['model']['num_timesteps'] = 1
251
+ return cfg
252
+
253
+ def ddgan_ddb_v3():
254
+ cfg = ddgan_ddb_v1()
255
+ cfg['model']['num_channels_dae'] = 192
256
+ cfg['model']['num_timesteps'] = 2
257
+ return cfg
258
+
259
  models = [
260
  ddgan_cifar10_cond17, # cifar10, cross attn for discr
261
  ddgan_cifar10_cond18, # cifar10, xl encoder
 
297
  ddgan_sd_v11,
298
  ddgan_laion2b_v2,
299
  ddgan_ddb_v1,
300
+ ddgan_ddb_v2,
301
+ ddgan_ddb_v3
302
  ]
303
 
304
  def get_model(model_name):