Commit c81908d by Mehdi Cherti
Parent: c334626

text to image support
pytorch_fid/fid_score.py CHANGED
@@ -140,7 +140,7 @@ def get_activations(files, model, batch_size=50, dims=2048, device='cpu', resize
                             batch_size=batch_size,
                             shuffle=False,
                             drop_last=False,
-                            num_workers=cpu_count())
+                            num_workers=8)

    pred_arr = np.empty((len(files), dims))

@@ -148,7 +148,7 @@ def get_activations(files, model, batch_size=50, dims=2048, device='cpu', resize

    for batch in tqdm(dataloader):
        batch = batch.to(device)
-
+        print(batch.shape, batch.min(), batch.max())
        with torch.no_grad():
            pred = model(batch)[0]
run.py ADDED
@@ -0,0 +1,127 @@
+import os
+from clize import run
+from glob import glob
+from subprocess import call
+
+def base():
+    return {
+        "slurm": {
+            "t": 360,
+            "N": 2,
+            "n": 8,
+        },
+        "model": {
+            "dataset": "wds",
+            "dataset_root": "/p/scratch/ccstdl/cherti1/CC12M/{00000..01099}.tar",
+            "image_size": 256,
+            "num_channels": 3,
+            "num_channels_dae": 128,
+            "ch_mult": "1 1 2 2 4 4",
+            "num_timesteps": 4,
+            "num_res_blocks": 2,
+            "batch_size": 8,
+            "num_epoch": 1000,
+            "ngf": 64,
+            "embedding_type": "positional",
+            "use_ema": "",
+            "ema_decay": 0.999,
+            "r1_gamma": 1.0,
+            "z_emb_dim": 256,
+            "lr_d": 1e-4,
+            "lr_g": 1.6e-4,
+            "lazy_reg": 10,
+            "save_content": "",
+            "save_ckpt_every": 1,
+            "masked_mean": "",
+            "resume": "",
+        }
+    }
+
+def ddgan_cc12m_v2():
+    cfg = base()
+    cfg['slurm']['N'] = 2
+    cfg['slurm']['n'] = 8
+    return cfg
+
+def ddgan_cc12m_v6():
+    cfg = base()
+    cfg['model']['text_encoder'] = "google/t5-v1_1-large"
+    return cfg
+
+def ddgan_cc12m_v7():
+    cfg = base()
+    cfg['model']['classifier_free_guidance_proba'] = 0.2
+    cfg['slurm']['N'] = 2
+    cfg['slurm']['n'] = 8
+    return cfg
+
+def ddgan_cc12m_v8():
+    cfg = base()
+    cfg['model']['text_encoder'] = "google/t5-v1_1-large"
+    cfg['model']['classifier_free_guidance_proba'] = 0.2
+    return cfg
+
+def ddgan_cc12m_v9():
+    cfg = base()
+    cfg['model']['text_encoder'] = "google/t5-v1_1-large"
+    cfg['model']['classifier_free_guidance_proba'] = 0.2
+    cfg['model']['num_channels_dae'] = 320
+    cfg['model']['image_size'] = 64
+    cfg['model']['batch_size'] = 1
+    return cfg
+
+
+def ddgan_cc12m_v11():
+    cfg = base()
+    cfg['model']['text_encoder'] = "google/t5-v1_1-large"
+    cfg['model']['classifier_free_guidance_proba'] = 0.2
+    cfg['model']['cross_attention'] = ""
+    return cfg
+
+models = [
+    ddgan_cc12m_v2,
+    ddgan_cc12m_v6,
+    ddgan_cc12m_v7,
+    ddgan_cc12m_v8,
+    ddgan_cc12m_v9,
+    ddgan_cc12m_v11,
+]
+
+def get_model(model_name):
+    for model in models:
+        if model.__name__ == model_name:
+            return model()
+
+
+def test(model_name, *, cond_text="", batch_size:int=None, epoch:int=None, guidance_scale:float=0, fid=False, real_img_dir=""):
+
+    cfg = get_model(model_name)
+    model = cfg['model']
+    if epoch is None:
+        paths = glob('./saved_info/dd_gan/{}/{}/netG_*.pth'.format(model["dataset"], model_name))
+        epoch = max(
+            [int(os.path.basename(path).replace(".pth", "").split("_")[1]) for path in paths]
+        )
+    args = {}
+    args['exp'] = model_name
+    args['image_size'] = model['image_size']
+    args['num_channels'] = model['num_channels']
+    args['dataset'] = model['dataset']
+    args['num_channels_dae'] = model['num_channels_dae']
+    args['ch_mult'] = model['ch_mult']
+    args['num_timesteps'] = model['num_timesteps']
+    args['num_res_blocks'] = model['num_res_blocks']
+    args['batch_size'] = model['batch_size'] if batch_size is None else batch_size
+    args['epoch'] = epoch
+    args['cond_text'] = f'"{cond_text}"'
+    args['text_encoder'] = model.get("text_encoder")
+    args['cross_attention'] = model.get("cross_attention")
+    args['guidance_scale'] = guidance_scale
+
+    if fid:
+        args['compute_fid'] = ''
+        args['real_img_dir'] = real_img_dir
+    cmd = "python test_ddgan.py " + " ".join(f"--{k} {v}" for k, v in args.items() if v is not None)
+    print(cmd)
+    call(cmd, shell=True)
+
+run([test])
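
run.py uses clize to expose test as a command-line entry point; the config dicts above are flattened into a test_ddgan.py invocation. A minimal sketch of the command string test() assembles, not part of the commit (the epoch number and prompt are hypothetical; normally the epoch is discovered from the newest netG_*.pth checkpoint):

# Sketch only: reproduces the command-building logic of test() above.
args = {
    "exp": "ddgan_cc12m_v8",             # hypothetical model name from `models`
    "image_size": 256,
    "epoch": 500,                        # hypothetical epoch
    "cond_text": '"a photo of a corgi"', # hypothetical prompt
    "text_encoder": "google/t5-v1_1-large",
    "guidance_scale": 2.0,
    "compute_fid": "",                   # empty string renders as a bare flag
    "cross_attention": None,             # None values are dropped entirely
}
cmd = "python test_ddgan.py " + " ".join(f"--{k} {v}" for k, v in args.items() if v is not None)
print(cmd)
# -> python test_ddgan.py --exp ddgan_cc12m_v8 --image_size 256 --epoch 500
#    --cond_text "a photo of a corgi" --text_encoder google/t5-v1_1-large
#    --guidance_scale 2.0 --compute_fid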
score_sde/models/discriminator.py CHANGED
@@ -96,11 +96,12 @@ class DownConvBlock(nn.Module):
 class Discriminator_small(nn.Module):
    """A time-dependent discriminator for small images (CIFAR10, StackMNIST)."""

-    def __init__(self, nc = 3, ngf = 64, t_emb_dim = 128, act=nn.LeakyReLU(0.2)):
+    def __init__(self, nc = 3, ngf = 64, t_emb_dim = 128, act=nn.LeakyReLU(0.2), cond_size=768):
        super().__init__()
        # Gaussian random feature embedding layer for time
        self.act = act
-
+        self.cond_proj = nn.Linear(cond_size, ngf*8)
+        # self.cond_proj.weight.data = default_initializer()(self.cond_proj.weight.shape)

        self.t_embed = TimestepEmbedding(
            embedding_dim=t_emb_dim,
@@ -131,10 +132,11 @@ class Discriminator_small(nn.Module):
        self.stddev_feat = 1


-    def forward(self, x, t, x_t):
-        t_embed = self.act(self.t_embed(t))
-
-
+    def forward(self, x, t, x_t, cond=None):
+        t_embed = self.t_embed(t)
+        # if cond is not None:
+        #     t_embed = t_embed + self.cond_proj(cond)
+        t_embed = self.act(t_embed)
        input_x = torch.cat((x, x_t), dim = 1)

        h0 = self.start_conv(input_x)
@@ -159,10 +161,9 @@ class Discriminator_small(nn.Module):

        out = self.final_conv(out)
        out = self.act(out)
-

        out = out.view(out.shape[0], out.shape[1], -1).sum(2)
-        out = self.end_linear(out)
+        out = self.end_linear(out) + (self.cond_proj(cond) * out).sum(dim=1, keepdim=True)

        return out

@@ -170,9 +171,10 @@ class Discriminator_small(nn.Module):
 class Discriminator_large(nn.Module):
    """A time-dependent discriminator for large images (CelebA, LSUN)."""

-    def __init__(self, nc = 1, ngf = 32, t_emb_dim = 128, act=nn.LeakyReLU(0.2)):
+    def __init__(self, nc = 1, ngf = 32, t_emb_dim = 128, act=nn.LeakyReLU(0.2), cond_size=768):
        super().__init__()
        # Gaussian random feature embedding layer for time
+        self.cond_proj = nn.Linear(cond_size, ngf*8)
        self.act = act

        self.t_embed = TimestepEmbedding(
@@ -202,8 +204,9 @@ class Discriminator_large(nn.Module):
        self.stddev_feat = 1


-    def forward(self, x, t, x_t):
-        t_embed = self.act(self.t_embed(t))
+    def forward(self, x, t, x_t, cond=None):
+        t_embed = self.t_embed(t)
+        t_embed = self.act(t_embed)

        input_x = torch.cat((x, x_t), dim = 1)

@@ -233,7 +236,6 @@ class Discriminator_large(nn.Module):
        out = self.act(out)

        out = out.view(out.shape[0], out.shape[1], -1).sum(2)
-        out = self.end_linear(out)
-
+        out = self.end_linear(out) + (self.cond_proj(cond) * out).sum(dim=1, keepdim=True)
        return out

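The new output term self.end_linear(out) + (self.cond_proj(cond) * out).sum(dim=1, keepdim=True) is projection-style conditioning in the spirit of Miyato & Koyama's projection discriminator: the unconditional logit plus an inner product between pooled features and the projected text embedding. Note that although cond defaults to None, this line uses it unconditionally, so both discriminators now require a condition. A standalone sketch, not part of the commit (shapes are illustrative):

import torch
import torch.nn as nn

B, feat_dim, cond_size = 4, 512, 768    # feat_dim stands in for ngf*8
end_linear = nn.Linear(feat_dim, 1)
cond_proj = nn.Linear(cond_size, feat_dim)

out = torch.randn(B, feat_dim)          # pooled discriminator features
cond = torch.randn(B, cond_size)        # pooled text embedding

# unconditional score + feature/condition inner product
score = end_linear(out) + (cond_proj(cond) * out).sum(dim=1, keepdim=True)
print(score.shape)                      # torch.Size([4, 1])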
score_sde/models/layers.py CHANGED
@@ -538,6 +538,276 @@ class AttnBlock(nn.Module):
        return x + h


+class CondAttnBlock(nn.Module):
+    """Cross-attention block: spatial features attend to text token embeddings."""
+    def __init__(self, channels, context_dim, dim_head=64, heads=8, norm_context=False, cosine_sim_attn=False):
+        super().__init__()
+        self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6)
+        self.ca = CrossAttention(
+            channels,
+            context_dim=context_dim,
+            dim_head=dim_head,
+            heads=heads,
+            norm_context=norm_context,
+            cosine_sim_attn=cosine_sim_attn,
+        )
+
+    def forward(self, x, cond, mask=None):
+        B, C, H, W = x.shape
+        h = self.GroupNorm_0(x)
+        h = h.view(B, C, H*W)
+        h = h.permute(0,2,1)
+        h = h.contiguous()
+        h_new = self.ca(h, cond, mask=mask)
+        h_new = h_new.permute(0,2,1)
+        h_new = h_new.contiguous()
+        h_new = h_new.view(B, C, H, W)
+        return x + h_new
+
+from torch import einsum
+from einops import rearrange, repeat, reduce
+from einops.layers.torch import Rearrange, Reduce
+from einops_exts import rearrange_many, repeat_many, check_shape
+from einops_exts.torch import EinopsToAndFrom
+
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if callable(d) else d
+
+class Identity(nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+
+    def forward(self, x, *args, **kwargs):
+        return x
+
+# assumed helper (referenced by the cosine_sim_attn path below but not shown
+# in this hunk): l2 normalisation along the feature dimension
+def l2norm(t):
+    return F.normalize(t, dim = -1)
+
+
+class CrossAttention(nn.Module):
+    def __init__(
+        self,
+        dim,
+        *,
+        context_dim = None,
+        dim_head = 64,
+        heads = 8,
+        norm_context = False,
+        cosine_sim_attn = False
+    ):
+        super().__init__()
+        self.scale = dim_head ** -0.5 if not cosine_sim_attn else 1.
+        self.cosine_sim_attn = cosine_sim_attn
+        self.cosine_sim_scale = 16 if cosine_sim_attn else 1
+
+        self.heads = heads
+        inner_dim = dim_head * heads
+
+        context_dim = default(context_dim, dim)
+
+        self.norm = nn.LayerNorm(dim)
+        self.norm_context = nn.LayerNorm(context_dim) if norm_context else Identity()
+
+        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
+        self.to_q = nn.Linear(dim, inner_dim, bias = False)
+        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
+
+        self.to_out = nn.Sequential(
+            nn.Linear(inner_dim, dim, bias = False),
+            nn.LayerNorm(dim)
+        )
+
+    def forward(self, x, context, mask = None):
+        b, n, device = *x.shape[:2], x.device
+
+        x = self.norm(x)
+        context = self.norm_context(context)
+
+        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
+
+        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = self.heads)
+
+        # add null key / value for classifier free guidance in prior net
+
+        nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b h 1 d', h = self.heads, b = b)
+
+        k = torch.cat((nk, k), dim = -2)
+        v = torch.cat((nv, v), dim = -2)
+
+        q = q * self.scale
+
+        # cosine sim attention
+
+        if self.cosine_sim_attn:
+            q, k = map(l2norm, (q, k))
+
+        # similarities
+
+        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.cosine_sim_scale
+
+        # masking
+
+        max_neg_value = -torch.finfo(sim.dtype).max
+
+        if exists(mask):
+            mask = F.pad(mask, (1, 0), value = True)
+            mask = rearrange(mask, 'b j -> b 1 1 j')
+            sim = sim.masked_fill(~mask, max_neg_value)
+
+        attn = sim.softmax(dim = -1, dtype = torch.float32)
+
+        out = einsum('b h i j, b h j d -> b h i d', attn, v)
+        out = rearrange(out, 'b h n d -> b n (h d)')
+        return self.to_out(out)
+
+
+class PerceiverAttention(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        dim_head = 64,
+        heads = 8,
+        cosine_sim_attn = False
+    ):
+        super().__init__()
+        self.scale = dim_head ** -0.5 if not cosine_sim_attn else 1
+        self.cosine_sim_attn = cosine_sim_attn
+        self.cosine_sim_scale = 16 if cosine_sim_attn else 1
+
+        self.heads = heads
+        inner_dim = dim_head * heads
+
+        self.norm = nn.LayerNorm(dim)
+        self.norm_latents = nn.LayerNorm(dim)
+
+        self.to_q = nn.Linear(dim, inner_dim, bias = False)
+        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
+
+        self.to_out = nn.Sequential(
+            nn.Linear(inner_dim, dim, bias = False),
+            nn.LayerNorm(dim)
+        )
+
+    def forward(self, x, latents, mask = None):
+        x = self.norm(x)
+        latents = self.norm_latents(latents)
+
+        b, h = x.shape[0], self.heads
+
+        q = self.to_q(latents)
+
+        # the paper differs from Perceiver in which they also concat the key / values derived from the latents to be attended to
+        kv_input = torch.cat((x, latents), dim = -2)
+        k, v = self.to_kv(kv_input).chunk(2, dim = -1)
+
+        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = h)
+
+        q = q * self.scale
+
+        # cosine sim attention
+
+        if self.cosine_sim_attn:
+            q, k = map(l2norm, (q, k))
+
+        # similarities and masking
+
+        sim = einsum('... i d, ... j d -> ... i j', q, k) * self.cosine_sim_scale
+
+        if exists(mask):
+            max_neg_value = -torch.finfo(sim.dtype).max
+            mask = F.pad(mask, (0, latents.shape[-2]), value = True)
+            mask = rearrange(mask, 'b j -> b 1 1 j')
+            sim = sim.masked_fill(~mask, max_neg_value)
+
+        # attention
+
+        attn = sim.softmax(dim = -1)
+
+        out = einsum('... i j, ... j d -> ... i d', attn, v)
+        out = rearrange(out, 'b h n d -> b n (h d)', h = h)
+        return self.to_out(out)
+
+
+def FeedForward(dim, mult = 2):
+    hidden_dim = int(dim * mult)
+    return nn.Sequential(
+        nn.LayerNorm(dim),
+        nn.Linear(dim, hidden_dim, bias = False),
+        nn.GELU(),
+        nn.LayerNorm(hidden_dim),
+        nn.Linear(hidden_dim, dim, bias = False)
+    )
+
+def exists(val):
+    return val is not None
+
+
+def masked_mean(t, *, dim, mask = None):
+    if not exists(mask):
+        return t.mean(dim = dim)
+    denom = mask.sum(dim = dim, keepdim = True)
+    mask = rearrange(mask, 'b n -> b n 1')
+    masked_t = t.masked_fill(~mask, 0.)
+
+    return masked_t.sum(dim = dim) / denom.clamp(min = 1e-5)
+
+class PerceiverResampler(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        depth,
+        dim_head = 64,
+        heads = 8,
+        num_latents = 64,
+        num_latents_mean_pooled = 4, # number of latents derived from mean pooled representation of the sequence
+        max_seq_len = 512,
+        ff_mult = 4,
+        cosine_sim_attn = False
+    ):
+        super().__init__()
+        self.pos_emb = nn.Embedding(max_seq_len, dim)
+
+        self.latents = nn.Parameter(torch.randn(num_latents, dim))
+
+        self.to_latents_from_mean_pooled_seq = None
+
+        if num_latents_mean_pooled > 0:
+            self.to_latents_from_mean_pooled_seq = nn.Sequential(
+                nn.LayerNorm(dim),
+                nn.Linear(dim, dim * num_latents_mean_pooled),
+                Rearrange('b (n d) -> b n d', n = num_latents_mean_pooled)
+            )
+
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(nn.ModuleList([
+                PerceiverAttention(dim = dim, dim_head = dim_head, heads = heads, cosine_sim_attn = cosine_sim_attn),
+                FeedForward(dim = dim, mult = ff_mult)
+            ]))
+
+    def forward(self, x, mask = None):
+        n, device = x.shape[1], x.device
+        pos_emb = self.pos_emb(torch.arange(n, device = device))
+
+        x_with_pos = x + pos_emb
+
+        latents = repeat(self.latents, 'n d -> b n d', b = x.shape[0])
+
+        if exists(self.to_latents_from_mean_pooled_seq):
+            meanpooled_seq = masked_mean(x, dim = 1, mask = torch.ones(x.shape[:2], device = x.device, dtype = torch.bool))
+            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
+            latents = torch.cat((meanpooled_latents, latents), dim = -2)
+
+        for attn, ff in self.layers:
+            latents = attn(x_with_pos, latents, mask = mask) + latents
+            latents = ff(latents) + latents
+
+        return latents
+
+
+
+
 class Upsample(nn.Module):
    def __init__(self, channels, with_conv=False):
        super().__init__()
@@ -616,4 +886,4 @@ class ResnetBlockDDPM(nn.Module):
            x = self.Conv_2(x)
        else:
            x = self.NIN_0(x)
-        return x + h
+        return x + h
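
CondAttnBlock treats every spatial position of the feature map as a query token, with the text tokens as keys and values; CrossAttention additionally prepends a learned null key/value pair so that even an empty caption leaves the attention with something to attend to, which is what classifier-free guidance relies on. A shape walkthrough, not part of the commit (dimensions are illustrative):

import torch

B, C, H, W = 2, 128, 16, 16            # C must match the block's `channels`
L, context_dim = 77, 1024              # caption length and text-encoder width

x = torch.randn(B, C, H, W)
h = x.view(B, C, H * W).permute(0, 2, 1)   # (B, H*W, C): one query per pixel
cond = torch.randn(B, L, context_dim)      # per-token text embeddings (keys/values)
mask = torch.ones(B, L, dtype=torch.bool)  # padding mask from the tokenizer

# CrossAttention(channels, context_dim=context_dim) maps h to (B, H*W, C),
# which the block reshapes back to (B, C, H, W) and adds residually to x.
print(h.shape, cond.shape, mask.shape)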
score_sde/models/ncsnpp_generator_adagn.py CHANGED
@@ -66,8 +66,10 @@ class NCSNpp(nn.Module):
        self.not_use_tanh = config.not_use_tanh
        self.act = act = nn.SiLU()
        self.z_emb_dim = z_emb_dim = config.z_emb_dim
-
        self.nf = nf = config.num_channels_dae
+        self.cond_proj = nn.Linear(config.cond_size, self.nf*4)
+        self.cond_proj.weight.data = default_initializer()(self.cond_proj.weight.shape)
+
        ch_mult = config.ch_mult
        self.num_res_blocks = num_res_blocks = config.num_res_blocks
        self.attn_resolutions = attn_resolutions = config.attn_resolutions
@@ -115,10 +117,12 @@ class NCSNpp(nn.Module):
        modules.append(nn.Linear(nf * 4, nf * 4))
        modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
        nn.init.zeros_(modules[-1].bias)
-
-        AttnBlock = functools.partial(layerspp.AttnBlockpp,
-                                      init_scale=init_scale,
-                                      skip_rescale=skip_rescale)
+        if config.cross_attention:
+            AttnBlock = functools.partial(layers.CondAttnBlock, context_dim=config.cond_size)
+        else:
+            AttnBlock = functools.partial(layerspp.AttnBlockpp,
+                                          init_scale=init_scale,
+                                          skip_rescale=skip_rescale)

        Upsample = functools.partial(layerspp.Upsample,
                                     with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)
@@ -277,7 +281,7 @@ class NCSNpp(nn.Module):
        self.z_transform = nn.Sequential(*mapping_layers)


-    def forward(self, x, time_cond, z):
+    def forward(self, x, time_cond, z, cond=None):
        # timestep/noise_level embedding; only for continuous training
        zemb = self.z_transform(z)
        modules = self.all_modules
@@ -296,9 +300,14 @@ class NCSNpp(nn.Module):

        else:
            raise ValueError(f'embedding type {self.embedding_type} unknown.')
-
+
+        if cond is not None:
+            cond_pooled, cond, cond_mask = cond
+
        if self.conditional:
            temb = modules[m_idx](temb)
+            if cond is not None:
+                temb = temb + self.cond_proj(cond_pooled)
            m_idx += 1
            temb = modules[m_idx](self.act(temb))
            m_idx += 1
@@ -322,7 +331,10 @@ class NCSNpp(nn.Module):
            h = modules[m_idx](hs[-1], temb, zemb)
            m_idx += 1
            if h.shape[-1] in self.attn_resolutions:
-                h = modules[m_idx](h)
+                if type(modules[m_idx]) == layers.CondAttnBlock:
+                    h = modules[m_idx](h, cond, cond_mask)
+                else:
+                    h = modules[m_idx](h)
                m_idx += 1

            hs.append(h)
@@ -354,7 +366,10 @@ class NCSNpp(nn.Module):
        h = hs[-1]
        h = modules[m_idx](h, temb, zemb)
        m_idx += 1
-        h = modules[m_idx](h)
+        if type(modules[m_idx]) == layers.CondAttnBlock:
+            h = modules[m_idx](h, cond, cond_mask)
+        else:
+            h = modules[m_idx](h)
        m_idx += 1
        h = modules[m_idx](h, temb, zemb)
        m_idx += 1
@@ -368,7 +383,10 @@ class NCSNpp(nn.Module):
            m_idx += 1

        if h.shape[-1] in self.attn_resolutions:
-            h = modules[m_idx](h)
+            if type(modules[m_idx]) == layers.CondAttnBlock:
+                h = modules[m_idx](h, cond, cond_mask)
+            else:
+                h = modules[m_idx](h)
            m_idx += 1

        if self.progressive != 'none':
@@ -429,3 +447,4 @@ class NCSNpp(nn.Module):
            return torch.tanh(h)
        else:
            return h
+
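
Two conditioning paths now run through the generator: the pooled text embedding is linearly projected and added to the time embedding (always active when a condition is passed), and with --cross_attention the attention blocks at the configured resolutions become CondAttnBlocks that cross-attend to the per-token embeddings under cond_mask. A minimal sketch of the first path, not part of the commit (dimensions are illustrative):

import torch
import torch.nn as nn

nf, cond_size = 128, 1024                # num_channels_dae and text-encoder width
cond_proj = nn.Linear(cond_size, nf * 4)

temb = torch.randn(2, nf * 4)            # time embedding, width nf*4
cond_pooled = torch.randn(2, cond_size)  # masked-mean-pooled T5 embedding

temb = temb + cond_proj(cond_pooled)     # pooled text shifts the time embedding
print(temb.shape)                        # torch.Size([2, 512])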
scripts/fid.sh ADDED
(empty file)
scripts/init.sh ADDED
@@ -0,0 +1,34 @@
+machine=$(cat /etc/FZJ/systemname)
+if [[ "$machine" == jurecadc ]]; then
+    export CUDA_VISIBLE_DEVICES=0,1,2,3
+    ml purge
+    ml use $OTHERSTAGES
+    ml Stages/2022
+    ml GCC/11.2.0
+    ml OpenMPI/4.1.2
+    ml CUDA/11.5
+    ml cuDNN/8.3.1.22-CUDA-11.5
+    ml NCCL/2.12.7-1-CUDA-11.5
+    ml PyTorch/1.11-CUDA-11.5
+    ml Horovod/0.24
+    ml torchvision/0.12.0
+    source /p/project/covidnetx/environments/jureca_2022/bin/activate
+fi
+if [[ "$machine" == juwelsbooster ]]; then
+    export CUDA_VISIBLE_DEVICES=0,1,2,3
+    ml purge
+    ml use $OTHERSTAGES
+    ml Stages/2022
+    ml GCC/11.2.0
+    ml OpenMPI/4.1.2
+    ml CUDA/11.5
+    ml cuDNN/8.3.1.22-CUDA-11.5
+    ml NCCL/2.12.7-1-CUDA-11.5
+    ml PyTorch/1.11-CUDA-11.5
+    ml Horovod/0.24
+    ml torchvision/0.12.0
+    source /p/project/covidnetx/environments/juwels_booster_2022/bin/activate
+fi
+if [[ "$machine" == jusuf ]]; then
+    echo not supported
+fi
scripts/init_2020.sh ADDED
@@ -0,0 +1,69 @@
+machine=$(cat /etc/FZJ/systemname)
+if [[ "$machine" == jurecadc ]]; then
+    export CUDA_VISIBLE_DEVICES=0,1,2,3
+    #ml use $OTHERSTAGES
+    #ml Stages/2020
+    #ml GCC/9.3.0
+    #ml OpenMPI/4.1.0rc1
+    #ml CUDA/11.0
+    #ml cuDNN/8.0.2.39-CUDA-11.0
+    #ml NCCL/2.8.3-1-CUDA-11.0
+    #ml PyTorch
+    #ml Horovod/0.20.3-Python-3.8.5
+    #ml scikit
+    #source /p/project/covidnetx/environments/jureca/bin/activate
+    ml purge
+    ml use $OTHERSTAGES
+    ml Stages/2020
+    ml GCC/10.3.0
+    ml OpenMPI/4.1.1
+    ml Horovod/0.23.0-Python-3.8.5
+    ml scikit
+    source /p/project/covidnetx/environments/jureca/bin/activate
+fi
+if [[ "$machine" == juwelsbooster ]]; then
+    export CUDA_VISIBLE_DEVICES=0,1,2,3
+    #ml use $OTHERSTAGES
+    #ml Stages/2020
+    #ml GCC/9.3.0
+    #ml OpenMPI/4.1.0rc1
+    #ml CUDA/11.0
+    #ml cuDNN/8.0.2.39-CUDA-11.0
+    #ml NCCL/2.8.3-1-CUDA-11.0
+    #ml PyTorch
+    #ml Horovod/0.20.3-Python-3.8.5
+    #ml scikit
+
+    #ml Stages/2021
+    #ml GCC
+    #ml OpenMPI
+    #ml CUDA
+    #ml cuDNN
+    #ml NCCL
+    #ml PyTorch
+    #ml Horovod
+    #ml scikit
+
+    ml purge
+    ml use $OTHERSTAGES
+    ml Stages/2020
+    ml GCC/10.3.0
+    ml OpenMPI/4.1.1
+    ml Horovod/0.23.0-Python-3.8.5
+    ml scikit
+    source /p/project/covidnetx/environments/juwels_booster/bin/activate
+fi
+if [[ "$machine" == jusuf ]]; then
+    ml purge
+    ml use $OTHERSTAGES
+    ml Stages/2020
+    ml GCC/9.3.0
+    ml OpenMPI/4.1.0rc1
+    ml CUDA/11.0
+    ml cuDNN/8.0.2.39-CUDA-11.0
+    ml NCCL/2.8.3-1-CUDA-11.0
+    ml PyTorch
+    ml Horovod/0.20.3-Python-3.8.5
+    #ml scikit
+    source /p/project/covidnetx/environments/jusuf/bin/activate
+fi
scripts/init_2022.sh ADDED
@@ -0,0 +1,34 @@
+machine=$(cat /etc/FZJ/systemname)
+if [[ "$machine" == jurecadc ]]; then
+    export CUDA_VISIBLE_DEVICES=0,1,2,3
+    ml purge
+    ml use $OTHERSTAGES
+    ml Stages/2022
+    ml GCC/11.2.0
+    ml OpenMPI/4.1.2
+    ml CUDA/11.5
+    ml cuDNN/8.3.1.22-CUDA-11.5
+    ml NCCL/2.12.7-1-CUDA-11.5
+    ml PyTorch/1.11-CUDA-11.5
+    ml Horovod/0.24
+    ml torchvision/0.12.0
+    source /p/project/covidnetx/environments/jureca_2022/bin/activate
+fi
+if [[ "$machine" == juwelsbooster ]]; then
+    export CUDA_VISIBLE_DEVICES=0,1,2,3
+    ml purge
+    ml use $OTHERSTAGES
+    ml Stages/2022
+    ml GCC/11.2.0
+    ml OpenMPI/4.1.2
+    ml CUDA/11.5
+    ml cuDNN/8.3.1.22-CUDA-11.5
+    ml NCCL/2.12.7-1-CUDA-11.5
+    ml PyTorch/1.11-CUDA-11.5
+    ml Horovod/0.24
+    ml torchvision/0.12.0
+    source /p/project/covidnetx/environments/juwels_booster_2022/bin/activate
+fi
+if [[ "$machine" == jusuf ]]; then
+    echo not supported
+fi
scripts/run_jurecadc_ddp.sh ADDED
@@ -0,0 +1,17 @@
+#!/bin/bash -x
+#SBATCH --account=zam
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=4
+#SBATCH --cpus-per-task=24
+#SBATCH --time=06:00:00
+#SBATCH --gres=gpu:4
+#SBATCH --partition=dc-gpu
+source set_torch_distributed_vars.sh
+#source scripts/init_2022.sh
+#source scripts/init_2020.sh
+source scripts/init.sh
+export CUDA_VISIBLE_DEVICES=0,1,2,3
+echo "Job id: $SLURM_JOB_ID"
+export TOKENIZERS_PARALLELISM=false
+export NCCL_ASYNC_ERROR_HANDLING=1
+srun python -u $*
scripts/run_jusuf_ddp.sh ADDED
@@ -0,0 +1,14 @@
+#!/bin/bash -x
+#SBATCH --account=zam
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=1
+#SBATCH --cpus-per-task=24
+#SBATCH --time=06:00:00
+#SBATCH --gres=gpu:1
+#SBATCH --partition=gpus
+source set_torch_distributed_vars.sh
+source scripts/init.sh
+export CUDA_VISIBLE_DEVICES=0,1,2,3
+echo "Job id: $SLURM_JOB_ID"
+export TOKENIZERS_PARALLELISM=false
+srun python -u $*
scripts/run_juwelsbooster_ddp.sh ADDED
@@ -0,0 +1,17 @@
+#!/bin/bash -x
+#SBATCH --account=covidnetx
+#SBATCH --nodes=4
+#SBATCH --ntasks-per-node=4
+#SBATCH --cpus-per-task=24
+#SBATCH --time=06:00:00
+#SBATCH --gres=gpu:4
+#SBATCH --partition=booster
+source set_torch_distributed_vars.sh
+#source scripts/init_2022.sh
+#source scripts/init_2020.sh
+source scripts/init.sh
+export CUDA_VISIBLE_DEVICES=0,1,2,3
+echo "Job id: $SLURM_JOB_ID"
+export TOKENIZERS_PARALLELISM=false
+export NCCL_ASYNC_ERROR_HANDLING=1
+srun python -u $*
t5.py ADDED
@@ -0,0 +1,99 @@
+import torch
+import transformers
+from transformers import T5Tokenizer, T5EncoderModel, T5Config
+
+transformers.logging.set_verbosity_error()
+
+def exists(val):
+    return val is not None
+
+# config
+
+MAX_LENGTH = 256
+
+DEFAULT_T5_NAME = 'google/t5-v1_1-base'
+
+T5_CONFIGS = {}
+
+# singleton globals
+
+def get_tokenizer(name):
+    tokenizer = T5Tokenizer.from_pretrained(name)
+    return tokenizer
+
+def get_model(name):
+    model = T5EncoderModel.from_pretrained(name)
+    return model
+
+def get_model_and_tokenizer(name):
+    global T5_CONFIGS
+
+    if name not in T5_CONFIGS:
+        T5_CONFIGS[name] = dict()
+    if "model" not in T5_CONFIGS[name]:
+        T5_CONFIGS[name]["model"] = get_model(name)
+    if "tokenizer" not in T5_CONFIGS[name]:
+        T5_CONFIGS[name]["tokenizer"] = get_tokenizer(name)
+
+    return T5_CONFIGS[name]['model'], T5_CONFIGS[name]['tokenizer']
+
+def get_encoded_dim(name):
+    if name not in T5_CONFIGS:
+        # avoids loading the model if we only want to get the dim
+        config = T5Config.from_pretrained(name)
+        T5_CONFIGS[name] = dict(config=config)
+    elif "config" in T5_CONFIGS[name]:
+        config = T5_CONFIGS[name]["config"]
+    elif "model" in T5_CONFIGS[name]:
+        config = T5_CONFIGS[name]["model"].config
+    else:
+        assert False
+    return config.d_model
+
+class T5Encoder(torch.nn.Module):
+
+    def __init__(self, name=DEFAULT_T5_NAME, max_length=MAX_LENGTH, padding='longest', masked_mean=False):
+        super().__init__()
+        self.name = name
+        self.t5, self.tokenizer = get_model_and_tokenizer(name)
+        self.max_length = max_length
+        self.output_size = get_encoded_dim(name)
+        self.padding = padding
+        self.masked_mean = masked_mean
+
+    def forward(self, x, return_only_pooled=True):
+        encoded = self.tokenizer.batch_encode_plus(
+            x,
+            return_tensors = "pt",
+            padding = self.padding,
+            max_length = self.max_length,
+            truncation = True
+        )
+        device = next(self.t5.parameters()).device
+        input_ids = encoded.input_ids.to(device)
+        attn_mask = encoded.attention_mask.to(device).bool()
+        output = self.t5(input_ids = input_ids, attention_mask = attn_mask)
+        encoded_text = output.last_hidden_state.detach()
+        # return encoded_text[:, 0]
+        # print(input_ids)
+        # print(attn_mask)
+        #if self.masked_mean:
+        pooled = masked_mean(encoded_text, dim=1, mask=attn_mask)
+        if return_only_pooled:
+            return pooled
+        else:
+            return pooled, encoded_text, attn_mask
+        #else:
+        #    return encoded_text.mean(dim=1)
+
+
+from einops import rearrange
+def masked_mean(t, *, dim, mask = None):
+    if not exists(mask):
+        return t.mean(dim = dim)
+
+    denom = mask.sum(dim = dim, keepdim = True)
+    mask = rearrange(mask, 'b n -> b n 1')
+    masked_t = t.masked_fill(~mask, 0.)
+
+    return masked_t.sum(dim = dim) / denom.clamp(min = 1e-5)
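
A minimal usage sketch of the encoder above, not part of the commit (it assumes the google/t5-v1_1-base weights can be downloaded; d_model is 768 for that checkpoint):

import torch
import t5

encoder = t5.T5Encoder(name="google/t5-v1_1-base")
with torch.no_grad():
    pooled, tokens, mask = encoder(
        ["a corgi wearing a top hat", "a watercolor landscape"],
        return_only_pooled=False,
    )
print(pooled.shape)   # (2, 768): masked mean over token embeddings
print(tokens.shape)   # (2, seq_len, 768): per-token embeddings for cross-attention
print(mask.shape)     # (2, seq_len): boolean attention mask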
test_ddgan.py CHANGED
@@ -7,12 +7,12 @@
 import argparse
 import torch
 import numpy as np
-
+import time
 import os
-
+import json
 import torchvision
 from score_sde.models.ncsnpp_generator_adagn import NCSNpp
-from pytorch_fid.fid_score import calculate_fid_given_paths
+import t5

 #%% Diffusion coefficients
 def var_func_vp(t, beta_min, beta_max):
@@ -112,7 +112,7 @@ def sample_posterior(coefficients, x_0,x_t, t):

    return sample_x_pos

-def sample_from_model(coefficients, generator, n_time, x_init, T, opt):
+def sample_from_model(coefficients, generator, n_time, x_init, T, opt, cond=None):
    x = x_init
    with torch.no_grad():
        for i in reversed(range(n_time)):
@@ -120,17 +120,70 @@ def sample_from_model(coefficients, generator, n_time, x_init, T, opt):
            t_time = t
            latent_z = torch.randn(x.size(0), opt.nz, device=x.device)#.to(x.device)
-            x_0 = generator(x, t_time, latent_z)
+            x_0 = generator(x, t_time, latent_z, cond=cond)
            x_new = sample_posterior(coefficients, x_0, x, t)
            x = x_new.detach()

    return x

+
+def sample_from_model_classifier_free_guidance(coefficients, generator, n_time, x_init, T, opt, text_encoder, cond=None, guidance_scale=0):
+    x = x_init
+    null = text_encoder([""] * len(x_init), return_only_pooled=False)
+    with torch.no_grad():
+        for i in reversed(range(n_time)):
+            t = torch.full((x.size(0),), i, dtype=torch.int64).to(x.device)
+
+            t_time = t
+            latent_z = torch.randn(x.size(0), opt.nz, device=x.device)
+
+            x_0_uncond = generator(x, t_time, latent_z, cond=null)
+            x_0_cond = generator(x, t_time, latent_z, cond=cond)
+
+            eps_uncond = (x - torch.sqrt(coefficients.alphas_cumprod[i]) * x_0_uncond) / torch.sqrt(1 - coefficients.alphas_cumprod[i])
+            eps_cond = (x - torch.sqrt(coefficients.alphas_cumprod[i]) * x_0_cond) / torch.sqrt(1 - coefficients.alphas_cumprod[i])
+
+            # eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)
+            eps = eps_uncond * (1 - guidance_scale) + eps_cond * guidance_scale
+            x_0 = (1/torch.sqrt(coefficients.alphas_cumprod[i])) * (x - torch.sqrt(1 - coefficients.alphas_cumprod[i]) * eps)
+
+            # Dynamic thresholding
+            q = opt.dynamic_thresholding_percentile
+            print("Before", x_0.min(), x_0.max())
+            if q:
+                shape = x_0.shape
+                x_0_v = x_0.view(shape[0], -1)
+                d = torch.quantile(torch.abs(x_0_v), q, dim=1, keepdim=True)
+                d.clamp_(min=1)
+                x_0_v = x_0_v.clamp(-d, d) / d
+                x_0 = x_0_v.view(shape)
+            print("After", x_0.min(), x_0.max())
+
+            x_new = sample_posterior(coefficients, x_0, x, t)
+
+            # Dynamic thresholding
+            # q = opt.dynamic_thresholding_percentile
+            # shape = x_new.shape
+            # x_new_v = x_new.view(shape[0], -1)
+            # d = torch.quantile(torch.abs(x_new_v), q, dim=1, keepdim=True)
+            # d = torch.maximum(d, torch.ones_like(d))
+            # d.clamp_(min = 1.)
+            # x_new_v = torch.clamp(x_new_v, -d, d) / d
+            # x_new = x_new_v.view(shape)
+            x = x_new.detach()
+
+    return x
+
+
 #%%
 def sample_and_test(args):
-    torch.manual_seed(42)
+    torch.manual_seed(args.seed)
    device = 'cuda:0'
-
+    text_encoder = t5.T5Encoder(name=args.text_encoder, masked_mean=args.masked_mean).to(device)
+    args.cond_size = text_encoder.output_size
+    # cond = text_encoder([str(yi%10) for yi in range(args.batch_size)])
+
    if args.dataset == 'cifar10':
        real_img_dir = 'pytorch_fid/cifar10_train_stat.npy'
    elif args.dataset == 'celeba_256':
@@ -157,7 +210,6 @@ def sample_and_test(args):

    pos_coeff = Posterior_Coefficients(args, device)

-    iters_needed = 50000 //args.batch_size

    save_dir = "./generated_samples/{}".format(args.dataset)

@@ -165,25 +217,90 @@ def sample_and_test(args):
        os.makedirs(save_dir)

    if args.compute_fid:
-        for i in range(iters_needed):
+        from torch.nn.functional import adaptive_avg_pool2d
+        from pytorch_fid.fid_score import calculate_activation_statistics, calculate_fid_given_paths, ImagePathDataset, compute_statistics_of_path, calculate_frechet_distance
+        from pytorch_fid.inception import InceptionV3
+
+        texts = open(args.cond_text).readlines()
+        #iters_needed = len(texts) // args.batch_size
+        #texts = list(map(lambda s:s.strip(), texts))
+        #ntimes = max(30000 // len(texts), 1)
+        #texts = texts * ntimes
+        print("Text size:", len(texts))
+        #print("Iters:", iters_needed)
+        i = 0
+        dims = 2048
+        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
+        inceptionv3 = InceptionV3([block_idx]).to(device)
+
+        if not args.real_img_dir.endswith("npz"):
+            real_mu, real_sigma = compute_statistics_of_path(
+                args.real_img_dir, inceptionv3, args.batch_size, dims, device,
+                resize=args.image_size,
+            )
+            np.savez("inception_statistics.npz", mu=real_mu, sigma=real_sigma)
+        else:
+            stats = np.load(args.real_img_dir)
+            real_mu = stats['mu']
+            real_sigma = stats['sigma']
+
+        fake_features = []
+        for b in range(0, len(texts), args.batch_size):
+            text = texts[b:b+args.batch_size]
            with torch.no_grad():
-                x_t_1 = torch.randn(args.batch_size, args.num_channels,args.image_size, args.image_size).to(device)
-                fake_sample = sample_from_model(pos_coeff, netG, args.num_timesteps, x_t_1,T, args)
-
+                cond = text_encoder(text, return_only_pooled=False)
+                bs = len(text)
+                t0 = time.time()
+                x_t_1 = torch.randn(bs, args.num_channels,args.image_size, args.image_size).to(device)
+                if args.guidance_scale:
+                    fake_sample = sample_from_model_classifier_free_guidance(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, text_encoder, cond=cond, guidance_scale=args.guidance_scale)
+                else:
+                    fake_sample = sample_from_model(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, cond=cond)
                fake_sample = to_range_0_1(fake_sample)
+            """
            for j, x in enumerate(fake_sample):
                index = i * args.batch_size + j
                torchvision.utils.save_image(x, './generated_samples/{}/{}.jpg'.format(args.dataset, index))
-            print('generating batch ', i)
-
-        paths = [save_dir, real_img_dir]
-
-        kwargs = {'batch_size': 100, 'device': device, 'dims': 2048}
-        fid = calculate_fid_given_paths(paths=paths, **kwargs)
+            """
+            with torch.no_grad():
+                pred = inceptionv3(fake_sample)[0]
+                # If model output is not scalar, apply global spatial average pooling.
+                # This happens if you choose a dimensionality not equal 2048.
+                if pred.size(2) != 1 or pred.size(3) != 1:
+                    pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
+                pred = pred.squeeze(3).squeeze(2).cpu().numpy()
+            fake_features.append(pred)
+            if i % 10 == 0:
+                print('generating batch ', i, time.time() - t0)
+            """
+            if i % 10 == 0:
+                ff = np.concatenate(fake_features)
+                fake_mu = np.mean(ff, axis=0)
+                fake_sigma = np.cov(ff, rowvar=False)
+                fid = calculate_frechet_distance(real_mu, real_sigma, fake_mu, fake_sigma)
+                print("FID", fid)
+            """
+            i += 1
+
+        fake_features = np.concatenate(fake_features)
+        fake_mu = np.mean(fake_features, axis=0)
+        fake_sigma = np.cov(fake_features, rowvar=False)
+        fid = calculate_frechet_distance(real_mu, real_sigma, fake_mu, fake_sigma)
+        dest = './saved_info/dd_gan/{}/{}/fid_{}.json'.format(args.dataset, args.exp, args.epoch_id)
+        results = {
+            "fid": fid,
+        }
+        results.update(vars(args))
+        with open(dest, "w") as fd:
+            json.dump(results, fd)
        print('FID = {}'.format(fid))
    else:
+        cond = text_encoder([args.cond_text] * args.batch_size, return_only_pooled=False)
        x_t_1 = torch.randn(args.batch_size, args.num_channels,args.image_size, args.image_size).to(device)
-        fake_sample = sample_from_model(pos_coeff, netG, args.num_timesteps, x_t_1,T, args)
+        if args.guidance_scale:
+            fake_sample = sample_from_model_classifier_free_guidance(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, text_encoder, cond=cond, guidance_scale=args.guidance_scale)
+        else:
+            fake_sample = sample_from_model(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, cond=cond)
        fake_sample = to_range_0_1(fake_sample)
        torchvision.utils.save_image(fake_sample, './samples_{}.jpg'.format(args.dataset))

@@ -198,6 +315,13 @@ if __name__ == '__main__':
    parser.add_argument('--compute_fid', action='store_true', default=False,
                        help='whether or not compute FID')
    parser.add_argument('--epoch_id', type=int,default=1000)
+    parser.add_argument('--guidance_scale', type=float,default=0)
+    parser.add_argument('--dynamic_thresholding_percentile', type=float,default=0)
+    parser.add_argument('--cond_text', type=str,default="0")
+
+    parser.add_argument('--cross_attention', action='store_true',default=False)
+
+
    parser.add_argument('--num_channels', type=int, default=3,
                        help='channel of image')
    parser.add_argument('--centered', action='store_false', default=True,
@@ -262,6 +386,8 @@ if __name__ == '__main__':
    parser.add_argument('--z_emb_dim', type=int, default=256)
    parser.add_argument('--t_emb_dim', type=int, default=256)
    parser.add_argument('--batch_size', type=int, default=200, help='sample generating batch size')
+    parser.add_argument('--text_encoder', type=str, default="google/t5-v1_1-base")
+    parser.add_argument('--masked_mean', action='store_true',default=False)



@@ -272,4 +398,4 @@ if __name__ == '__main__':
    sample_and_test(args)


-
+
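
The guidance loop derives an implicit epsilon from each x_0 prediction, mixes the conditional and unconditional epsilons, and converts back; the active line and the commented one above it are the same formula in two algebraic forms, and the optional dynamic-thresholding step clamps x_0 to a per-sample quantile in the spirit of Imagen. A quick numeric check of the mixing, not part of the commit (values are illustrative):

import torch

guidance_scale = 2.0
eps_uncond = torch.tensor([0.10])
eps_cond = torch.tensor([0.30])

eps_a = eps_uncond + guidance_scale * (eps_cond - eps_uncond)          # commented form
eps_b = eps_uncond * (1 - guidance_scale) + eps_cond * guidance_scale  # active form
assert torch.allclose(eps_a, eps_b)
print(eps_a)   # tensor([0.5000]): scale > 1 extrapolates past eps_cond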
train_ddgan.py CHANGED
@@ -18,7 +18,7 @@ import torch.optim as optim
18
  import torchvision
19
 
20
  import torchvision.transforms as transforms
21
- from torchvision.datasets import CIFAR10
22
  from datasets_prep.lsun import LSUN
23
  from datasets_prep.stackmnist_data import StackedMNIST, _data_transforms_stacked_mnist
24
  from datasets_prep.lmdb_datasets import LMDBDataset
@@ -27,6 +27,11 @@ from datasets_prep.lmdb_datasets import LMDBDataset
27
  from torch.multiprocessing import Process
28
  import torch.distributed as dist
29
  import shutil
 
 
 
 
 
30
 
31
  def copy_source(file, output_dir):
32
  shutil.copyfile(file, os.path.join(output_dir, os.path.basename(file)))
@@ -172,7 +177,7 @@ def sample_posterior(coefficients, x_0,x_t, t):
172
 
173
  return sample_x_pos
174
 
175
- def sample_from_model(coefficients, generator, n_time, x_init, T, opt):
176
  x = x_init
177
  with torch.no_grad():
178
  for i in reversed(range(n_time)):
@@ -180,13 +185,15 @@ def sample_from_model(coefficients, generator, n_time, x_init, T, opt):
180
 
181
  t_time = t
182
  latent_z = torch.randn(x.size(0), opt.nz, device=x.device)
183
- x_0 = generator(x, t_time, latent_z)
184
  x_new = sample_posterior(coefficients, x_0, x, t)
185
  x = x_new.detach()
186
 
187
  return x
188
 
189
- #%%
 
 
190
  def train(rank, gpu, args):
191
  from score_sde.models.discriminator import Discriminator_small, Discriminator_large
192
  from score_sde.models.ncsnpp_generator_adagn import NCSNpp
@@ -236,37 +243,81 @@ def train(rank, gpu, args):
236
  transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
237
  ])
238
  dataset = LMDBDataset(root='/datasets/celeba-lmdb/', name='celeba', train=True, transform=train_transform)
239
-
240
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
 
242
- train_sampler = torch.utils.data.distributed.DistributedSampler(dataset,
243
- num_replicas=args.world_size,
244
- rank=rank)
245
- data_loader = torch.utils.data.DataLoader(dataset,
246
- batch_size=batch_size,
 
247
  shuffle=False,
248
  num_workers=4,
 
249
  pin_memory=True,
250
- sampler=train_sampler,
251
- drop_last = True)
252
-
253
  netG = NCSNpp(args).to(device)
 
 
 
 
254
 
255
 
256
  if args.dataset == 'cifar10' or args.dataset == 'stackmnist':
257
  netD = Discriminator_small(nc = 2*args.num_channels, ngf = args.ngf,
258
  t_emb_dim = args.t_emb_dim,
 
259
  act=nn.LeakyReLU(0.2)).to(device)
260
  else:
261
  netD = Discriminator_large(nc = 2*args.num_channels, ngf = args.ngf,
262
  t_emb_dim = args.t_emb_dim,
 
263
  act=nn.LeakyReLU(0.2)).to(device)
264
 
265
  broadcast_params(netG.parameters())
266
  broadcast_params(netD.parameters())
267
 
268
  optimizerD = optim.Adam(netD.parameters(), lr=args.lr_d, betas = (args.beta1, args.beta2))
269
-
270
  optimizerG = optim.Adam(netG.parameters(), lr=args.lr_g, betas = (args.beta1, args.beta2))
271
 
272
  if args.use_ema:
@@ -297,9 +348,9 @@ def train(rank, gpu, args):
297
  pos_coeff = Posterior_Coefficients(args, device)
298
  T = get_time_schedule(args, device)
299
 
300
- if args.resume:
301
- checkpoint_file = os.path.join(exp_path, 'content.pth')
302
- checkpoint = torch.load(checkpoint_file, map_location=device)
303
  init_epoch = checkpoint['epoch']
304
  epoch = init_epoch
305
  netG.load_state_dict(checkpoint['netG_dict'])
@@ -319,9 +370,22 @@ def train(rank, gpu, args):
319
 
320
 
321
  for epoch in range(init_epoch, args.num_epoch+1):
322
- train_sampler.set_epoch(epoch)
 
 
 
323
 
324
  for iteration, (x, y) in enumerate(data_loader):
 
 
 
 
 
 
 
 
 
 
325
  for p in netD.parameters():
326
  p.requires_grad = True
327
 
@@ -339,7 +403,7 @@ def train(rank, gpu, args):
339
 
340
 
341
  # train with real
342
- D_real = netD(x_t, t, x_tp1.detach()).view(-1)
343
 
344
  errD_real = F.softplus(-D_real)
345
  errD_real = errD_real.mean()
@@ -375,10 +439,10 @@ def train(rank, gpu, args):
375
  latent_z = torch.randn(batch_size, nz, device=device)
376
 
377
 
378
- x_0_predict = netG(x_tp1.detach(), t, latent_z)
379
  x_pos_sample = sample_posterior(pos_coeff, x_0_predict, x_tp1, t)
380
 
381
- output = netD(x_pos_sample, t, x_tp1.detach()).view(-1)
382
 
383
 
384
  errD_fake = F.softplus(output)
@@ -407,11 +471,10 @@ def train(rank, gpu, args):
407
 
408
 
409
 
410
-
411
- x_0_predict = netG(x_tp1.detach(), t, latent_z)
412
  x_pos_sample = sample_posterior(pos_coeff, x_0_predict, x_tp1, t)
413
 
414
- output = netD(x_pos_sample, t, x_tp1.detach()).view(-1)
415
 
416
 
417
  errG = F.softplus(-output)
@@ -426,7 +489,27 @@ def train(rank, gpu, args):
426
  if iteration % 100 == 0:
427
  if rank == 0:
428
  print('epoch {} iteration{}, G Loss: {}, D Loss: {}'.format(epoch,iteration, errG.item(), errD.item()))
429
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
430
  if not args.no_lr_decay:
431
 
432
  schedulerG.step()
@@ -437,7 +520,7 @@ def train(rank, gpu, args):
437
  torchvision.utils.save_image(x_pos_sample, os.path.join(exp_path, 'xpos_epoch_{}.png'.format(epoch)), normalize=True)
438
 
439
  x_t_1 = torch.randn_like(real_data)
440
- fake_sample = sample_from_model(pos_coeff, netG, args.num_timesteps, x_t_1, T, args)
441
  torchvision.utils.save_image(fake_sample, os.path.join(exp_path, 'sample_discrete_epoch_{}.png'.format(epoch)), normalize=True)
442
 
443
  if args.save_content:
@@ -449,6 +532,7 @@ def train(rank, gpu, args):
449
  'optimizerD': optimizerD.state_dict(), 'schedulerD': schedulerD.state_dict()}
450
 
451
  torch.save(content, os.path.join(exp_path, 'content.pth'))
 
452
 
453
  if epoch % args.save_ckpt_every == 0:
454
  if args.use_ema:
@@ -462,11 +546,19 @@ def train(rank, gpu, args):
462
 
463
  def init_processes(rank, size, fn, args):
464
  """ Initialize the distributed environment. """
 
 
 
 
 
 
 
 
465
  os.environ['MASTER_ADDR'] = args.master_address
466
- os.environ['MASTER_PORT'] = '6020'
467
  torch.cuda.set_device(args.local_rank)
468
  gpu = args.local_rank
469
- dist.init_process_group(backend='nccl', init_method='env://', rank=rank, world_size=size)
470
  fn(rank, gpu, args)
471
  dist.barrier()
472
  cleanup()
@@ -480,7 +572,10 @@ if __name__ == '__main__':
480
  help='seed used for initialization')
481
 
482
  parser.add_argument('--resume', action='store_true',default=False)
483
-
 
 
 
484
  parser.add_argument('--image_size', type=int, default=32,
485
  help='size of image')
486
  parser.add_argument('--num_channels', type=int, default=3,
@@ -492,7 +587,7 @@ if __name__ == '__main__':
492
  help='beta_min for diffusion')
493
  parser.add_argument('--beta_max', type=float, default=20.,
494
  help='beta_max for diffusion')
495
-
496
 
497
  parser.add_argument('--num_channels_dae', type=int, default=128,
498
  help='number of initial channels in denosing model')
@@ -534,6 +629,7 @@ if __name__ == '__main__':
534
  #geenrator and training
535
  parser.add_argument('--exp', default='experiment_cifar_default', help='name of experiment')
536
  parser.add_argument('--dataset', default='cifar10', help='name of dataset')
 
537
  parser.add_argument('--nz', type=int, default=100)
538
  parser.add_argument('--num_timesteps', type=int, default=4)
539
 
@@ -577,26 +673,28 @@ if __name__ == '__main__':
577
 
578
 
579
  args = parser.parse_args()
580
- args.world_size = args.num_proc_node * args.num_process_per_node
581
- size = args.num_process_per_node
582
-
583
- if size > 1:
584
- processes = []
585
- for rank in range(size):
586
- args.local_rank = rank
587
- global_rank = rank + args.node_rank * args.num_process_per_node
588
- global_size = args.num_proc_node * args.num_process_per_node
589
- args.global_rank = global_rank
590
- print('Node rank %d, local proc %d, global proc %d' % (args.node_rank, rank, global_rank))
591
- p = Process(target=init_processes, args=(global_rank, global_size, train, args))
592
- p.start()
593
- processes.append(p)
594
-
595
- for p in processes:
596
- p.join()
597
- else:
598
- print('starting in debug mode')
 
 
599
 
600
- init_processes(0, size, train, args)
601
 
602
-
 
18
  import torchvision
19
 
20
  import torchvision.transforms as transforms
21
+ from torchvision.datasets import CIFAR10, ImageFolder
22
  from datasets_prep.lsun import LSUN
23
  from datasets_prep.stackmnist_data import StackedMNIST, _data_transforms_stacked_mnist
24
  from datasets_prep.lmdb_datasets import LMDBDataset
 
27
  from torch.multiprocessing import Process
28
  import torch.distributed as dist
29
  import shutil
30
+ import logging
31
+ import t5
32
+ def log_and_continue(exn):
33
+ logging.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.')
34
+ return True
35
 
36
  def copy_source(file, output_dir):
37
  shutil.copyfile(file, os.path.join(output_dir, os.path.basename(file)))
 
177
 
178
  return sample_x_pos
179
 
180
+ def sample_from_model(coefficients, generator, n_time, x_init, T, opt, cond=None):
181
  x = x_init
182
  with torch.no_grad():
183
  for i in reversed(range(n_time)):
 
185
 
186
  t_time = t
187
  latent_z = torch.randn(x.size(0), opt.nz, device=x.device)
188
+ x_0 = generator(x, t_time, latent_z, cond=cond)
189
  x_new = sample_posterior(coefficients, x_0, x, t)
190
  x = x_new.detach()
191
 
192
  return x
193
 
194
+
195
+ from utils import ResampledShards2
196
+
197
  def train(rank, gpu, args):
198
  from score_sde.models.discriminator import Discriminator_small, Discriminator_large
199
  from score_sde.models.ncsnpp_generator_adagn import NCSNpp
 
243
  transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
244
  ])
245
  dataset = LMDBDataset(root='/datasets/celeba-lmdb/', name='celeba', train=True, transform=train_transform)
246
+ elif args.dataset == "image_folder":
247
+ train_transform = transforms.Compose([
248
+ transforms.Resize(args.image_size),
249
+ transforms.CenterCrop(args.image_size),
250
+ # transforms.RandomHorizontalFlip(),
251
+ transforms.ToTensor(),
252
+ transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
253
+ ])
254
+ dataset = ImageFolder(root=args.dataset_root, transform=train_transform)
255
+ elif args.dataset == 'wds':
256
+ import webdataset as wds
257
+ train_transform = transforms.Compose([
258
+ transforms.Resize(args.image_size),
259
+ transforms.CenterCrop(args.image_size),
260
+ # transforms.RandomHorizontalFlip(),
261
+ transforms.ToTensor(),
262
+ transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
263
+ ])
264
+ # pipeline = [wds.SimpleShardList(args.dataset_root)]
265
+ pipeline = [ResampledShards2(args.dataset_root)]
266
+ pipeline.extend([
267
+ wds.split_by_node,
268
+ wds.split_by_worker,
269
+ wds.tarfile_to_samples(handler=log_and_continue),
270
+ ])
271
+ pipeline.extend([
272
+ wds.decode("pilrgb", handler=log_and_continue),
273
+ wds.rename(image="jpg;png"),
274
+ wds.map_dict(image=train_transform),
275
+ wds.to_tuple("image","txt"),
276
+ wds.batched(batch_size, partial=False),
277
+ ])
278
+ dataset = wds.DataPipeline(*pipeline)
279
+ data_loader = wds.WebLoader(
280
+ dataset,
281
+ batch_size=None,
282
+ shuffle=False,
283
+ num_workers=8,
284
+ )
285
 
286
+ if args.dataset != "wds":
287
+ train_sampler = torch.utils.data.distributed.DistributedSampler(dataset,
288
+ num_replicas=args.world_size,
289
+ rank=rank)
290
+ data_loader = torch.utils.data.DataLoader(dataset,
291
+ batch_size=batch_size,
292
  shuffle=False,
293
  num_workers=4,
294
+ drop_last=True,
295
  pin_memory=True,
296
+ sampler=train_sampler,)
297
+ text_encoder = t5.T5Encoder(name=args.text_encoder, masked_mean=args.masked_mean).to(device)
298
+ args.cond_size = text_encoder.output_size
299
  netG = NCSNpp(args).to(device)
300
+ nb_params = 0
301
+ for param in netG.parameters():
302
+ nb_params += param.flatten().shape[0]
303
+ print("Number of generator parameters:", nb_params)
304
 
      if args.dataset == 'cifar10' or args.dataset == 'stackmnist':
          netD = Discriminator_small(nc = 2*args.num_channels, ngf = args.ngf,
                                     t_emb_dim = args.t_emb_dim,
+                                    cond_size=text_encoder.output_size,
                                     act=nn.LeakyReLU(0.2)).to(device)
      else:
          netD = Discriminator_large(nc = 2*args.num_channels, ngf = args.ngf,
                                     t_emb_dim = args.t_emb_dim,
+                                    cond_size=text_encoder.output_size,
                                     act=nn.LeakyReLU(0.2)).to(device)

      broadcast_params(netG.parameters())
      broadcast_params(netD.parameters())

      optimizerD = optim.Adam(netD.parameters(), lr=args.lr_d, betas = (args.beta1, args.beta2))
      optimizerG = optim.Adam(netG.parameters(), lr=args.lr_g, betas = (args.beta1, args.beta2))

      if args.use_ema:
 
      pos_coeff = Posterior_Coefficients(args, device)
      T = get_time_schedule(args, device)

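+     # Resume from the rolling 'content.pth' checkpoint when present.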
+     checkpoint_file = os.path.join(exp_path, 'content.pth')
+     if args.resume and os.path.exists(checkpoint_file):
+         checkpoint = torch.load(checkpoint_file, map_location="cpu")
          init_epoch = checkpoint['epoch']
          epoch = init_epoch
          netG.load_state_dict(checkpoint['netG_dict'])
 
      for epoch in range(init_epoch, args.num_epoch+1):
+         if args.dataset == "wds":
+             os.environ["WDS_EPOCH"] = str(epoch)
+         else:
+             train_sampler.set_epoch(epoch)

          for iteration, (x, y) in enumerate(data_loader):
+             if args.dataset != "wds":
+                 y = [str(yi) for yi in y.tolist()]
+
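+             # Classifier-free guidance: with probability p, drop the caption
+             # (replace it with "") so the model also learns the unconditional distribution.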
+             if args.classifier_free_guidance_proba:
+                 u = (np.random.uniform(size=len(y)) <= args.classifier_free_guidance_proba).tolist()
+                 y = ["" if ui else yi for yi, ui in zip(y, u)]
+
+             with torch.no_grad():
+                 cond_pooled, cond, cond_mask = text_encoder(y, return_only_pooled=False)
+
              for p in netD.parameters():
                  p.requires_grad = True

 

              # train with real
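+             # The discriminator is conditioned on the pooled caption embedding only.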
+             D_real = netD(x_t, t, x_tp1.detach(), cond=cond_pooled).view(-1)

              errD_real = F.softplus(-D_real)
              errD_real = errD_real.mean()

              latent_z = torch.randn(batch_size, nz, device=device)

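+             # Generator condition: the pooled embedding plus the full token sequence
+             # and mask (used for cross-attention when enabled).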
+             x_0_predict = netG(x_tp1.detach(), t, latent_z, cond=(cond_pooled, cond, cond_mask))
              x_pos_sample = sample_posterior(pos_coeff, x_0_predict, x_tp1, t)

+             output = netD(x_pos_sample, t, x_tp1.detach(), cond=cond_pooled).view(-1)

              errD_fake = F.softplus(output)

 
+             x_0_predict = netG(x_tp1.detach(), t, latent_z, cond=(cond_pooled, cond, cond_mask))
              x_pos_sample = sample_posterior(pos_coeff, x_0_predict, x_tp1, t)

+             output = netD(x_pos_sample, t, x_tp1.detach(), cond=cond_pooled).view(-1)

              errG = F.softplus(-output)

              if iteration % 100 == 0:
                  if rank == 0:
                      print('epoch {} iteration {}, G Loss: {}, D Loss: {}'.format(epoch, iteration, errG.item(), errD.item()))
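+             # Every 1000 iterations: sample a grid from the current generator and
+             # checkpoint training state (rank 0 only).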
+             if iteration % 1000 == 0:
+                 x_t_1 = torch.randn_like(real_data)
+                 fake_sample = sample_from_model(pos_coeff, netG, args.num_timesteps, x_t_1, T, args, cond=(cond_pooled, cond, cond_mask))
+                 if rank == 0:
+                     torchvision.utils.save_image(fake_sample, os.path.join(exp_path, 'sample_discrete_epoch_{}_iteration_{}.png'.format(epoch, iteration)), normalize=True)
+                     if args.save_content:
+                         print('Saving content.')
+                         content = {'epoch': epoch + 1, 'global_step': global_step, 'args': args,
+                                    'netG_dict': netG.state_dict(), 'optimizerG': optimizerG.state_dict(),
+                                    'schedulerG': schedulerG.state_dict(), 'netD_dict': netD.state_dict(),
+                                    'optimizerD': optimizerD.state_dict(), 'schedulerD': schedulerD.state_dict()}
+
+                         torch.save(content, os.path.join(exp_path, 'content.pth'))
+                         torch.save(content, os.path.join(exp_path, 'content_backup.pth'))
+                     if args.use_ema:
+                         optimizerG.swap_parameters_with_ema(store_params_in_ema=True)
+
+                     torch.save(netG.state_dict(), os.path.join(exp_path, 'netG_{}.pth'.format(epoch)))
+                     if args.use_ema:
+                         optimizerG.swap_parameters_with_ema(store_params_in_ema=True)
+
          if not args.no_lr_decay:

              schedulerG.step()

              torchvision.utils.save_image(x_pos_sample, os.path.join(exp_path, 'xpos_epoch_{}.png'.format(epoch)), normalize=True)

              x_t_1 = torch.randn_like(real_data)
+             fake_sample = sample_from_model(pos_coeff, netG, args.num_timesteps, x_t_1, T, args, cond=(cond_pooled, cond, cond_mask))
              torchvision.utils.save_image(fake_sample, os.path.join(exp_path, 'sample_discrete_epoch_{}.png'.format(epoch)), normalize=True)

              if args.save_content:

                             'optimizerD': optimizerD.state_dict(), 'schedulerD': schedulerD.state_dict()}

                  torch.save(content, os.path.join(exp_path, 'content.pth'))
+                 torch.save(content, os.path.join(exp_path, 'content_backup.pth'))

              if epoch % args.save_ckpt_every == 0:
                  if args.use_ema:

 
  def init_processes(rank, size, fn, args):
      """ Initialize the distributed environment. """
+
+     import os
+
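+     # Under SLURM, derive rank, world size, local rank, and the master address
+     # from the environment variables set by the launcher.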
+     args.rank = int(os.environ['SLURM_PROCID'])
+     args.world_size = int(os.getenv("SLURM_NTASKS"))
+     args.local_rank = int(os.environ['SLURM_LOCALID'])
+     print(args.rank, args.world_size)
+     args.master_address = os.getenv("SLURM_LAUNCH_NODE_IPADDR")
      os.environ['MASTER_ADDR'] = args.master_address
+     os.environ['MASTER_PORT'] = "12345"
      torch.cuda.set_device(args.local_rank)
      gpu = args.local_rank
+     dist.init_process_group(backend='nccl', init_method='env://', rank=rank, world_size=args.world_size)
      fn(rank, gpu, args)
      dist.barrier()
      cleanup()

                          help='seed used for initialization')

      parser.add_argument('--resume', action='store_true', default=False)
+     parser.add_argument('--masked_mean', action='store_true', default=False)
+     parser.add_argument('--text_encoder', type=str, default="google/t5-v1_1-base")
+     parser.add_argument('--cross_attention', action='store_true', default=False)
+
      parser.add_argument('--image_size', type=int, default=32,
                          help='size of image')
      parser.add_argument('--num_channels', type=int, default=3,

                          help='beta_min for diffusion')
      parser.add_argument('--beta_max', type=float, default=20.,
                          help='beta_max for diffusion')
+     parser.add_argument('--classifier_free_guidance_proba', type=float, default=0.0,
+                         help='probability of dropping the text condition during training (classifier-free guidance)')

      parser.add_argument('--num_channels_dae', type=int, default=128,
                          help='number of initial channels in denoising model')

      # generator and training
      parser.add_argument('--exp', default='experiment_cifar_default', help='name of experiment')
      parser.add_argument('--dataset', default='cifar10', help='name of dataset')
+     parser.add_argument('--dataset_root', default='', help='path to the dataset root (or shard pattern for wds)')
      parser.add_argument('--nz', type=int, default=100)
      parser.add_argument('--num_timesteps', type=int, default=4)

 
      args = parser.parse_args()
+     # args.world_size = args.num_proc_node * args.num_process_per_node
+     args.world_size = int(os.getenv("SLURM_NTASKS"))
+     args.rank = int(os.environ['SLURM_PROCID'])
+     # size = args.num_process_per_node
+     init_processes(args.rank, args.world_size, train, args)
+     # if size > 1:
+     #     processes = []
+     #     for rank in range(size):
+     #         args.local_rank = rank
+     #         global_rank = rank + args.node_rank * args.num_process_per_node
+     #         global_size = args.num_proc_node * args.num_process_per_node
+     #         args.global_rank = global_rank
+     #         print('Node rank %d, local proc %d, global proc %d' % (args.node_rank, rank, global_rank))
+     #         p = Process(target=init_processes, args=(global_rank, global_size, train, args))
+     #         p.start()
+     #         processes.append(p)
+
+     #     for p in processes:
+     #         p.join()
+     # else:
+     #     print('starting in debug mode')
+
+     #     init_processes(0, size, train, args)
+
utils.py ADDED
@@ -0,0 +1,67 @@
+ from multiprocessing import Value
+ from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, IterableDataset, get_worker_info
+ import webdataset as wds
+ import braceexpand
+ import random
+ import sys
+
+ def pytorch_worker_seed():
+     """get dataloader worker seed from pytorch"""
+     worker_info = get_worker_info()
+     if worker_info is not None:
+         # favour the seed already created for pytorch dataloader workers if it exists
+         return worker_info.seed
+     # fallback to wds rank based seed
+     return wds.utils.pytorch_worker_seed()
+
+
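+ # Epoch counter held in shared memory so dataloader worker processes see updates.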
+ class SharedEpoch:
+     def __init__(self, epoch: int = 0):
+         self.shared_epoch = Value('i', epoch)
+
+     def set_value(self, epoch):
+         self.shared_epoch.value = epoch
+
+     def get_value(self):
+         return self.shared_epoch.value
+
+
+ class ResampledShards2(IterableDataset):
+     """An iterable dataset yielding shard URLs sampled with replacement."""
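+     # Resampling with replacement gives every worker an endless, evenly
+     # distributed stream of shards instead of a fixed, possibly uneven split.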
+
+     def __init__(
+         self,
+         urls,
+         nshards=sys.maxsize,
+         worker_seed=None,
+         deterministic=False,
+         epoch=-1,
+     ):
+         """Sample shards from the shard list with replacement.
+
+         :param urls: a list of URLs as a Python list or brace notation string
+         """
+         super().__init__()
+         # urls = wds.shardlists.expand_urls(urls)
+         urls = list(braceexpand.braceexpand(urls))
+         self.urls = urls
+         assert isinstance(self.urls[0], str)
+         self.nshards = nshards
+         self.rng = random.Random()
+         self.worker_seed = pytorch_worker_seed if worker_seed is None else worker_seed
+         self.deterministic = deterministic
+         self.epoch = epoch
+
+     def __iter__(self):
+         """Return an iterator over the shards."""
+         if isinstance(self.epoch, SharedEpoch):
+             epoch = self.epoch.get_value()
+         else:
+             # NOTE: this epoch tracking is problematic in a multiprocess (dataloader
+             # workers or train) situation, as different workers may wrap at
+             # different times (or not at all).
+             self.epoch += 1
+             epoch = self.epoch
+         if self.deterministic:
+             # reset seed w/ epoch if deterministic; worker seed should be deterministic due to args.seed
+             self.rng.seed(self.worker_seed() + epoch)
+         for _ in range(self.nshards):
+             yield dict(url=self.rng.choice(self.urls))