Spaces: Runtime error

Mehdi Cherti committed
Commit c81908d
Parent(s): c334626

text to image support

Files changed:
- pytorch_fid/fid_score.py +2 -2
- run.py +127 -0
- score_sde/models/discriminator.py +15 -13
- score_sde/models/layers.py +271 -1
- score_sde/models/ncsnpp_generator_adagn.py +29 -10
- scripts/fid.sh +0 -0
- scripts/init.sh +34 -0
- scripts/init_2020.sh +69 -0
- scripts/init_2022.sh +34 -0
- scripts/run_jurecadc_ddp.sh +17 -0
- scripts/run_jusuf_ddp.sh +14 -0
- scripts/run_juwelsbooster_ddp.sh +17 -0
- t5.py +99 -0
- test_ddgan.py +146 -20
- train_ddgan.py +150 -52
- utils.py +67 -0
pytorch_fid/fid_score.py
CHANGED

@@ -140,7 +140,7 @@ def get_activations(files, model, batch_size=50, dims=2048, device='cpu', resize
                             batch_size=batch_size,
                             shuffle=False,
                             drop_last=False,
-                            num_workers=
+                            num_workers=8)
 
     pred_arr = np.empty((len(files), dims))
 
@@ -148,7 +148,7 @@ def get_activations(files, model, batch_size=50, dims=2048, device='cpu', resize
 
     for batch in tqdm(dataloader):
         batch = batch.to(device)
-
+        print(batch.shape, batch.min(), batch.max)
         with torch.no_grad():
             pred = model(batch)[0]
 
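A note on the debug line added above: `batch.max` is referenced without parentheses, so the statement prints the bound method object rather than the tensor's maximum. The presumably intended call is:

    print(batch.shape, batch.min(), batch.max())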
run.py
ADDED

@@ -0,0 +1,127 @@
import os
from clize import run
from glob import glob
from subprocess import call

def base():
    return {
        "slurm":{
            "t": 360,
            "N": 2,
            "n": 8,
        },
        "model":{
            "dataset" :"wds",
            "dataset_root": "/p/scratch/ccstdl/cherti1/CC12M/{00000..01099}.tar",
            "image_size": 256,
            "num_channels": 3,
            "num_channels_dae": 128,
            "ch_mult": "1 1 2 2 4 4",
            "num_timesteps": 4,
            "num_res_blocks": 2,
            "batch_size": 8,
            "num_epoch": 1000,
            "ngf": 64,
            "embedding_type": "positional",
            "use_ema": "",
            "ema_decay": 0.999,
            "r1_gamma": 1.0,
            "z_emb_dim": 256,
            "lr_d": 1e-4,
            "lr_g": 1.6e-4,
            "lazy_reg": 10,
            "save_content": "",
            "save_ckpt_every": 1,
            "masked_mean": "",
            "resume": "",
        }
    }
def ddgan_cc12m_v2():
    cfg = base()
    cfg['slurm']['N'] = 2
    cfg['slurm']['n'] = 8
    return cfg

def ddgan_cc12m_v6():
    cfg = base()
    cfg['model']['text_encoder'] = "google/t5-v1_1-large"
    return cfg

def ddgan_cc12m_v7():
    cfg = base()
    cfg['model']['classifier_free_guidance_proba'] = 0.2
    cfg['slurm']['N'] = 2
    cfg['slurm']['n'] = 8
    return cfg

def ddgan_cc12m_v8():
    cfg = base()
    cfg['model']['text_encoder'] = "google/t5-v1_1-large"
    cfg['model']['classifier_free_guidance_proba'] = 0.2
    return cfg

def ddgan_cc12m_v9():
    cfg = base()
    cfg['model']['text_encoder'] = "google/t5-v1_1-large"
    cfg['model']['classifier_free_guidance_proba'] = 0.2
    cfg['model']['num_channels_dae'] = 320
    cfg['model']['image_size'] = 64
    cfg['model']['batch_size'] = 1
    return cfg


def ddgan_cc12m_v11():
    cfg = base()
    cfg['model']['text_encoder'] = "google/t5-v1_1-large"
    cfg['model']['classifier_free_guidance_proba'] = 0.2
    cfg['model']['cross_attention'] = ""
    return cfg

models = [
    ddgan_cc12m_v2,
    ddgan_cc12m_v6,
    ddgan_cc12m_v7,
    ddgan_cc12m_v8,
    ddgan_cc12m_v9,
    ddgan_cc12m_v11,
]
def get_model(model_name):
    for model in models:
        if model.__name__ == model_name:
            return model()


def test(model_name, *, cond_text="", batch_size:int=None, epoch:int=None, guidance_scale:float=0, fid=False, real_img_dir=""):

    cfg = get_model(model_name)
    model = cfg['model']
    if epoch is None:
        paths = glob('./saved_info/dd_gan/{}/{}/netG_*.pth'.format(model["dataset"], model_name))
        epoch = max(
            [int(os.path.basename(path).replace(".pth", "").split("_")[1]) for path in paths]
        )
    args = {}
    args['exp'] = model_name
    args['image_size'] = model['image_size']
    args['num_channels'] = model['num_channels']
    args['dataset'] = model['dataset']
    args['num_channels_dae'] = model['num_channels_dae']
    args['ch_mult'] = model['ch_mult']
    args['num_timesteps'] = model['num_timesteps']
    args['num_res_blocks'] = model['num_res_blocks']
    args['batch_size'] = model['batch_size'] if batch_size is None else batch_size
    args['epoch'] = epoch
    args['cond_text'] = f'"{cond_text}"'
    args['text_encoder'] = model.get("text_encoder")
    args['cross_attention'] = model.get("cross_attention")
    args['guidance_scale'] = guidance_scale

    if fid:
        args['compute_fid'] = ''
        args['real_img_dir'] = real_img_dir
    cmd = "python test_ddgan.py " + " ".join(f"--{k} {v}" for k, v in args.items() if v is not None)
    print(cmd)
    call(cmd, shell=True)

run([test])
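For reference, a hypothetical invocation of this launcher: clize exposes `test` as a subcommand and turns the keyword-only parameters into dashed options. The model name comes from the `models` list above, while the prompt, scale, and batch size below are made-up values:

    python run.py test ddgan_cc12m_v11 --cond-text="a photo of a red car" --guidance-scale=2.0 --batch-size=4

When --epoch is omitted, `test` scans ./saved_info/dd_gan/<dataset>/<model_name>/netG_*.pth and uses the highest epoch it finds.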
score_sde/models/discriminator.py
CHANGED

@@ -96,11 +96,12 @@ class DownConvBlock(nn.Module):
 class Discriminator_small(nn.Module):
     """A time-dependent discriminator for small images (CIFAR10, StackMNIST)."""
 
-    def __init__(self, nc = 3, ngf = 64, t_emb_dim = 128, act=nn.LeakyReLU(0.2)):
+    def __init__(self, nc = 3, ngf = 64, t_emb_dim = 128, act=nn.LeakyReLU(0.2), cond_size=768):
         super().__init__()
         # Gaussian random feature embedding layer for time
         self.act = act
-
+        self.cond_proj = nn.Linear(cond_size, ngf*8)
+        # self.cond_proj.weight.data = default_initializer()(self.cond_proj.weight.shape)
 
         self.t_embed = TimestepEmbedding(
             embedding_dim=t_emb_dim,

@@ -131,10 +132,11 @@ class Discriminator_small(nn.Module):
         self.stddev_feat = 1
 
 
-    def forward(self, x, t, x_t):
-        t_embed = self.act(self.t_embed(t))
-
-
+    def forward(self, x, t, x_t, cond=None):
+        t_embed = self.t_embed(t)
+        # if cond is not None:
+        #     t_embed = t_embed + self.cond_proj(cond)
+        t_embed = self.act(t_embed)
         input_x = torch.cat((x, x_t), dim = 1)
 
         h0 = self.start_conv(input_x)

@@ -159,10 +161,9 @@ class Discriminator_small(nn.Module):
 
         out = self.final_conv(out)
         out = self.act(out)
-
 
         out = out.view(out.shape[0], out.shape[1], -1).sum(2)
-        out = self.end_linear(out)
+        out = self.end_linear(out) + (self.cond_proj(cond) * out).sum(dim=1, keepdim=True)
 
         return out
 

@@ -170,9 +171,10 @@ class Discriminator_small(nn.Module):
 class Discriminator_large(nn.Module):
     """A time-dependent discriminator for large images (CelebA, LSUN)."""
 
-    def __init__(self, nc = 1, ngf = 32, t_emb_dim = 128, act=nn.LeakyReLU(0.2)):
+    def __init__(self, nc = 1, ngf = 32, t_emb_dim = 128, act=nn.LeakyReLU(0.2), cond_size=768):
         super().__init__()
         # Gaussian random feature embedding layer for time
+        self.cond_proj = nn.Linear(cond_size, ngf*8)
         self.act = act
 
         self.t_embed = TimestepEmbedding(

@@ -202,8 +204,9 @@ class Discriminator_large(nn.Module):
         self.stddev_feat = 1
 
 
-    def forward(self, x, t, x_t):
-        t_embed = self.act(self.t_embed(t))
+    def forward(self, x, t, x_t, cond=None):
+        t_embed = self.t_embed(t)
+        t_embed = self.act(t_embed)
 
         input_x = torch.cat((x, x_t), dim = 1)
 

@@ -233,7 +236,6 @@ class Discriminator_large(nn.Module):
         out = self.act(out)
 
         out = out.view(out.shape[0], out.shape[1], -1).sum(2)
-        out = self.end_linear(out)
-
+        out = self.end_linear(out) + (self.cond_proj(cond) * out).sum(dim=1, keepdim=True)
         return out
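The discriminator change follows a projection-discriminator-style conditioning: the pooled text embedding is mapped to the width of the final features (ngf*8), and its inner product with those features is added to the unconditional logit. A minimal self-contained sketch of that output head (sizes and names here are illustrative, not taken from the repo):

    import torch
    import torch.nn as nn

    ngf, cond_size, batch = 32, 768, 4
    cond_proj = nn.Linear(cond_size, ngf * 8)   # maps text embedding to feature width
    end_linear = nn.Linear(ngf * 8, 1)          # unconditional logit

    feats = torch.randn(batch, ngf * 8)         # pooled discriminator features (B, ngf*8)
    cond = torch.randn(batch, cond_size)        # pooled text embedding (B, cond_size)

    # unconditional logit plus projection term <cond_proj(cond), feats>, one scalar per sample
    out = end_linear(feats) + (cond_proj(cond) * feats).sum(dim=1, keepdim=True)
    print(out.shape)  # torch.Size([4, 1])

Note that although `forward` declares `cond=None`, the projection term is applied unconditionally, so calling either discriminator without a condition would raise a TypeError.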
score_sde/models/layers.py
CHANGED

@@ -538,6 +538,276 @@ class AttnBlock(nn.Module):
         return x + h
 
 
+class CondAttnBlock(nn.Module):
+    """Channel-wise self-attention block."""
+    def __init__(self, channels, context_dim, dim_head=64, heads=8, norm_context=False, cosine_sim_attn=False):
+        super().__init__()
+        self.GroupNorm_0 = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6)
+        self.ca = CrossAttention(
+            channels,
+            context_dim=context_dim,
+            dim_head=dim_head,
+            heads=heads,
+            norm_context=norm_context,
+            cosine_sim_attn=cosine_sim_attn,
+        )
+
+    def forward(self, x, cond, mask=None):
+        B, C, H, W = x.shape
+        h = self.GroupNorm_0(x)
+        h = h.view(B, C, H*W)
+        h = h.permute(0,2,1)
+        h = h.contiguous()
+        h_new = self.ca(h, cond, mask=mask)
+        h_new = h_new.permute(0,2,1)
+        h_new = h_new.contiguous()
+        h_new = h_new.view(B, C, H, W)
+        return x + h_new
+
+from torch import einsum
+from einops import rearrange, repeat, reduce
+from einops.layers.torch import Rearrange, Reduce
+from einops_exts import rearrange_many, repeat_many, check_shape
+from einops_exts.torch import EinopsToAndFrom
+
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if callable(d) else d
+
+class Identity(nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+
+    def forward(self, x, *args, **kwargs):
+        return x
+
+
+class CrossAttention(nn.Module):
+    def __init__(
+        self,
+        dim,
+        *,
+        context_dim = None,
+        dim_head = 64,
+        heads = 8,
+        norm_context = False,
+        cosine_sim_attn = False
+    ):
+        super().__init__()
+        self.scale = dim_head ** -0.5 if not cosine_sim_attn else 1.
+        self.cosine_sim_attn = cosine_sim_attn
+        self.cosine_sim_scale = 16 if cosine_sim_attn else 1
+
+        self.heads = heads
+        inner_dim = dim_head * heads
+
+        context_dim = default(context_dim, dim)
+
+        self.norm = nn.LayerNorm(dim)
+        self.norm_context = nn.LayerNorm(context_dim) if norm_context else Identity()
+
+        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
+        self.to_q = nn.Linear(dim, inner_dim, bias = False)
+        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
+
+        self.to_out = nn.Sequential(
+            nn.Linear(inner_dim, dim, bias = False),
+            nn.LayerNorm(dim)
+        )
+
+    def forward(self, x, context, mask = None):
+        b, n, device = *x.shape[:2], x.device
+
+        x = self.norm(x)
+        context = self.norm_context(context)
+
+        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
+
+        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = self.heads)
+
+        # add null key / value for classifier free guidance in prior net
+
+        nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b h 1 d', h = self.heads, b = b)
+
+        k = torch.cat((nk, k), dim = -2)
+        v = torch.cat((nv, v), dim = -2)
+
+        q = q * self.scale
+
+        # cosine sim attention
+
+        if self.cosine_sim_attn:
+            q, k = map(l2norm, (q, k))
+
+        # similarities
+
+        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.cosine_sim_scale
+
+        # masking
+
+        max_neg_value = -torch.finfo(sim.dtype).max
+
+        if exists(mask):
+            mask = F.pad(mask, (1, 0), value = True)
+            mask = rearrange(mask, 'b j -> b 1 1 j')
+            sim = sim.masked_fill(~mask, max_neg_value)
+
+        attn = sim.softmax(dim = -1, dtype = torch.float32)
+
+        out = einsum('b h i j, b h j d -> b h i d', attn, v)
+        out = rearrange(out, 'b h n d -> b n (h d)')
+        return self.to_out(out)
+
+
+class PerceiverAttention(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        dim_head = 64,
+        heads = 8,
+        cosine_sim_attn = False
+    ):
+        super().__init__()
+        self.scale = dim_head ** -0.5 if not cosine_sim_attn else 1
+        self.cosine_sim_attn = cosine_sim_attn
+        self.cosine_sim_scale = 16 if cosine_sim_attn else 1
+
+        self.heads = heads
+        inner_dim = dim_head * heads
+
+        self.norm = nn.LayerNorm(dim)
+        self.norm_latents = nn.LayerNorm(dim)
+
+        self.to_q = nn.Linear(dim, inner_dim, bias = False)
+        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
+
+        self.to_out = nn.Sequential(
+            nn.Linear(inner_dim, dim, bias = False),
+            nn.LayerNorm(dim)
+        )
+
+    def forward(self, x, latents, mask = None):
+        x = self.norm(x)
+        latents = self.norm_latents(latents)
+
+        b, h = x.shape[0], self.heads
+
+        q = self.to_q(latents)
+
+        # the paper differs from Perceiver in which they also concat the key / values derived from the latents to be attended to
+        kv_input = torch.cat((x, latents), dim = -2)
+        k, v = self.to_kv(kv_input).chunk(2, dim = -1)
+
+        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = h)
+
+        q = q * self.scale
+
+        # cosine sim attention
+
+        if self.cosine_sim_attn:
+            q, k = map(l2norm, (q, k))
+
+        # similarities and masking
+
+        sim = einsum('... i d, ... j d -> ... i j', q, k) * self.cosine_sim_scale
+
+        if exists(mask):
+            max_neg_value = -torch.finfo(sim.dtype).max
+            mask = F.pad(mask, (0, latents.shape[-2]), value = True)
+            mask = rearrange(mask, 'b j -> b 1 1 j')
+            sim = sim.masked_fill(~mask, max_neg_value)
+
+        # attention
+
+        attn = sim.softmax(dim = -1)
+
+        out = einsum('... i j, ... j d -> ... i d', attn, v)
+        out = rearrange(out, 'b h n d -> b n (h d)', h = h)
+        return self.to_out(out)
+
+
+def FeedForward(dim, mult = 2):
+    hidden_dim = int(dim * mult)
+    return nn.Sequential(
+        nn.LayerNorm(dim),
+        nn.Linear(dim, hidden_dim, bias = False),
+        nn.GELU(),
+        nn.LayerNorm(hidden_dim),
+        nn.Linear(hidden_dim, dim, bias = False)
+    )
+
+def exists(val):
+    return val is not None
+
+
+def masked_mean(t, *, dim, mask = None):
+    if not exists(mask):
+        return t.mean(dim = dim)
+    denom = mask.sum(dim = dim, keepdim = True)
+    mask = rearrange(mask, 'b n -> b n 1')
+    masked_t = t.masked_fill(~mask, 0.)
+
+    return masked_t.sum(dim = dim) / denom.clamp(min = 1e-5)
+
+class PerceiverResampler(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        depth,
+        dim_head = 64,
+        heads = 8,
+        num_latents = 64,
+        num_latents_mean_pooled = 4, # number of latents derived from mean pooled representation of the sequence
+        max_seq_len = 512,
+        ff_mult = 4,
+        cosine_sim_attn = False
+    ):
+        super().__init__()
+        self.pos_emb = nn.Embedding(max_seq_len, dim)
+
+        self.latents = nn.Parameter(torch.randn(num_latents, dim))
+
+        self.to_latents_from_mean_pooled_seq = None
+
+        if num_latents_mean_pooled > 0:
+            self.to_latents_from_mean_pooled_seq = nn.Sequential(
+                nn.LayerNorm(dim),
+                nn.Linear(dim, dim * num_latents_mean_pooled),
+                Rearrange('b (n d) -> b n d', n = num_latents_mean_pooled)
+            )
+
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(nn.ModuleList([
+                PerceiverAttention(dim = dim, dim_head = dim_head, heads = heads, cosine_sim_attn = cosine_sim_attn),
+                FeedForward(dim = dim, mult = ff_mult)
+            ]))
+
+    def forward(self, x, mask = None):
+        n, device = x.shape[1], x.device
+        pos_emb = self.pos_emb(torch.arange(n, device = device))
+
+        x_with_pos = x + pos_emb
+
+        latents = repeat(self.latents, 'n d -> b n d', b = x.shape[0])
+
+        if exists(self.to_latents_from_mean_pooled_seq):
+            meanpooled_seq = masked_mean(x, dim = 1, mask = torch.ones(x.shape[:2], device = x.device, dtype = torch.bool))
+            meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
+            latents = torch.cat((meanpooled_latents, latents), dim = -2)
+
+        for attn, ff in self.layers:
+            latents = attn(x_with_pos, latents, mask = mask) + latents
+            latents = ff(latents) + latents
+
+        return latents
+
+
+
 class Upsample(nn.Module):
     def __init__(self, channels, with_conv=False):
         super().__init__()

@@ -616,4 +886,4 @@ class ResnetBlockDDPM(nn.Module):
             x = self.Conv_2(x)
         else:
             x = self.NIN_0(x)
-        return x + h
+        return x + h
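To make the data flow of CondAttnBlock concrete: it flattens the (B, C, H, W) feature map into H*W channel vectors, cross-attends from those spatial positions to the text token sequence, and adds the result back residually. A minimal sketch of the same reshape-attend-reshape pattern, using torch.nn.MultiheadAttention as a stand-in for the repo's CrossAttention (all sizes illustrative):

    import torch
    import torch.nn as nn

    B, C, H, W, T_ctx, ctx_dim = 2, 64, 8, 8, 12, 768

    norm = nn.GroupNorm(num_groups=32, num_channels=C)
    # kdim/vdim let keys and values come from the text-context dimension
    attn = nn.MultiheadAttention(embed_dim=C, num_heads=8, kdim=ctx_dim, vdim=ctx_dim, batch_first=True)

    x = torch.randn(B, C, H, W)             # image features
    cond = torch.randn(B, T_ctx, ctx_dim)   # per-token text embeddings

    h = norm(x).view(B, C, H * W).permute(0, 2, 1)   # (B, H*W, C): one query per pixel
    h_new, _ = attn(h, cond, cond)                   # queries = pixels, keys/values = text
    h_new = h_new.permute(0, 2, 1).view(B, C, H, W)
    out = x + h_new                                  # residual, as in CondAttnBlock
    print(out.shape)  # torch.Size([2, 64, 8, 8])

Unlike this sketch, the repo's CrossAttention also prepends a learned null key/value pair, which gives every position something to attend to when the text is empty or masked out.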
score_sde/models/ncsnpp_generator_adagn.py
CHANGED

@@ -66,8 +66,10 @@ class NCSNpp(nn.Module):
         self.not_use_tanh = config.not_use_tanh
         self.act = act = nn.SiLU()
         self.z_emb_dim = z_emb_dim = config.z_emb_dim
-
         self.nf = nf = config.num_channels_dae
+        self.cond_proj = nn.Linear(config.cond_size, self.nf*4)
+        self.cond_proj.weight.data = default_initializer()(self.cond_proj.weight.shape)
+
         ch_mult = config.ch_mult
         self.num_res_blocks = num_res_blocks = config.num_res_blocks
         self.attn_resolutions = attn_resolutions = config.attn_resolutions

@@ -115,10 +117,12 @@ class NCSNpp(nn.Module):
         modules.append(nn.Linear(nf * 4, nf * 4))
         modules[-1].weight.data = default_initializer()(modules[-1].weight.shape)
         nn.init.zeros_(modules[-1].bias)
-        AttnBlock = functools.partial(layerspp.AttnBlockpp,
-                                      init_scale=init_scale,
-                                      skip_rescale=skip_rescale)
-
+        if config.cross_attention:
+            AttnBlock = functools.partial(layers.CondAttnBlock, context_dim=config.cond_size)
+        else:
+            AttnBlock = functools.partial(layerspp.AttnBlockpp,
+                                          init_scale=init_scale,
+                                          skip_rescale=skip_rescale)
 
         Upsample = functools.partial(layerspp.Upsample,
                                      with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)

@@ -277,7 +281,7 @@ class NCSNpp(nn.Module):
         self.z_transform = nn.Sequential(*mapping_layers)
 
 
-    def forward(self, x, time_cond, z):
+    def forward(self, x, time_cond, z, cond=None):
         # timestep/noise_level embedding; only for continuous training
         zemb = self.z_transform(z)
         modules = self.all_modules

@@ -296,9 +300,14 @@ class NCSNpp(nn.Module):
 
         else:
             raise ValueError(f'embedding type {self.embedding_type} unknown.')
-
+
+        if cond is not None:
+            cond_pooled, cond, cond_mask = cond
+
         if self.conditional:
             temb = modules[m_idx](temb)
+            if cond is not None:
+                temb = temb + self.cond_proj(cond_pooled)
             m_idx += 1
             temb = modules[m_idx](self.act(temb))
             m_idx += 1

@@ -322,7 +331,10 @@ class NCSNpp(nn.Module):
                 h = modules[m_idx](hs[-1], temb, zemb)
                 m_idx += 1
                 if h.shape[-1] in self.attn_resolutions:
-                    h = modules[m_idx](h)
+                    if type(modules[m_idx]) == layers.CondAttnBlock:
+                        h = modules[m_idx](h, cond, cond_mask)
+                    else:
+                        h = modules[m_idx](h)
                     m_idx += 1
 
                 hs.append(h)

@@ -354,7 +366,10 @@ class NCSNpp(nn.Module):
         h = hs[-1]
         h = modules[m_idx](h, temb, zemb)
         m_idx += 1
-        h = modules[m_idx](h)
+        if type(modules[m_idx]) == layers.CondAttnBlock:
+            h = modules[m_idx](h, cond, cond_mask)
+        else:
+            h = modules[m_idx](h)
         m_idx += 1
         h = modules[m_idx](h, temb, zemb)
         m_idx += 1

@@ -368,7 +383,10 @@ class NCSNpp(nn.Module):
                     m_idx += 1
 
                 if h.shape[-1] in self.attn_resolutions:
-                    h = modules[m_idx](h)
+                    if type(modules[m_idx]) == layers.CondAttnBlock:
+                        h = modules[m_idx](h, cond, cond_mask)
+                    else:
+                        h = modules[m_idx](h)
                     m_idx += 1
 
                 if self.progressive != 'none':

@@ -429,3 +447,4 @@ class NCSNpp(nn.Module):
             return torch.tanh(h)
         else:
             return h
+
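The generator's `cond` argument is a triple: the pooled text vector is added into the timestep embedding through `cond_proj`, while the full token sequence and its mask feed the CondAttnBlock cross-attention layers. A short sketch of how that triple is assembled with the t5.py module from this commit (the prompt is illustrative):

    import torch
    import t5

    text_encoder = t5.T5Encoder(name="google/t5-v1_1-base")
    pooled, tokens, mask = text_encoder(["a painting of a fox"], return_only_pooled=False)
    cond = (pooled, tokens, mask)  # unpacked in NCSNpp.forward as:
                                   # cond_pooled, cond, cond_mask = cond
    # pooled: (B, d_model)    -> temb = temb + self.cond_proj(cond_pooled)
    # tokens: (B, T, d_model) -> context for CondAttnBlock cross-attention
    # mask:   (B, T) bool     -> masks padded tokens in the attention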
scripts/fid.sh
ADDED
File without changes
scripts/init.sh
ADDED

@@ -0,0 +1,34 @@
machine=$(cat /etc/FZJ/systemname)
if [[ "$machine" == jurecadc ]]; then
    export CUDA_VISIBLE_DEVICES=0,1,2,3
    ml purge
    ml use $OTHERSTAGES
    ml Stages/2022
    ml GCC/11.2.0
    ml OpenMPI/4.1.2
    ml CUDA/11.5
    ml cuDNN/8.3.1.22-CUDA-11.5
    ml NCCL/2.12.7-1-CUDA-11.5
    ml PyTorch/1.11-CUDA-11.5
    ml Horovod/0.24
    ml torchvision/0.12.0
    source /p/project/covidnetx/environments/jureca_2022/bin/activate
fi
if [[ "$machine" == juwelsbooster ]]; then
    export CUDA_VISIBLE_DEVICES=0,1,2,3
    ml purge
    ml use $OTHERSTAGES
    ml Stages/2022
    ml GCC/11.2.0
    ml OpenMPI/4.1.2
    ml CUDA/11.5
    ml cuDNN/8.3.1.22-CUDA-11.5
    ml NCCL/2.12.7-1-CUDA-11.5
    ml PyTorch/1.11-CUDA-11.5
    ml Horovod/0.24
    ml torchvision/0.12.0
    source /p/project/covidnetx/environments/juwels_booster_2022/bin/activate
fi
if [[ "$machine" == jusuf ]]; then
    echo not supported
fi
scripts/init_2020.sh
ADDED

@@ -0,0 +1,69 @@
machine=$(cat /etc/FZJ/systemname)
if [[ "$machine" == jurecadc ]]; then
    export CUDA_VISIBLE_DEVICES=0,1,2,3
    #ml use $OTHERSTAGES
    #ml Stages/2020
    #ml GCC/9.3.0
    #ml OpenMPI/4.1.0rc1
    #ml CUDA/11.0
    #ml cuDNN/8.0.2.39-CUDA-11.0
    #ml NCCL/2.8.3-1-CUDA-11.0
    #ml PyTorch
    #ml Horovod/0.20.3-Python-3.8.5
    #ml scikit
    #source /p/project/covidnetx/environments/jureca/bin/activate
    ml purge
    ml use $OTHERSTAGES
    ml Stages/2020
    ml GCC/10.3.0
    ml OpenMPI/4.1.1
    ml Horovod/0.23.0-Python-3.8.5
    ml scikit
    source /p/project/covidnetx/environments/jureca/bin/activate
fi
if [[ "$machine" == juwelsbooster ]]; then
    export CUDA_VISIBLE_DEVICES=0,1,2,3
    #ml use $OTHERSTAGES
    #ml Stages/2020
    #ml GCC/9.3.0
    #ml OpenMPI/4.1.0rc1
    #ml CUDA/11.0
    #ml cuDNN/8.0.2.39-CUDA-11.0
    #ml NCCL/2.8.3-1-CUDA-11.0
    #ml PyTorch
    #ml Horovod/0.20.3-Python-3.8.5
    #ml scikit

    #ml Stages/2021
    #ml GCC
    #ml OpenMPI
    #ml CUDA
    #ml cuDNN
    #ml NCCL
    #ml PyTorch
    #ml Horovod
    #ml scikit

    ml purge
    ml use $OTHERSTAGES
    ml Stages/2020
    ml GCC/10.3.0
    ml OpenMPI/4.1.1
    ml Horovod/0.23.0-Python-3.8.5
    ml scikit
    source /p/project/covidnetx/environments/juwels_booster/bin/activate
fi
if [[ "$machine" == jusuf ]]; then
    ml purge
    ml use $OTHERSTAGES
    ml Stages/2020
    ml GCC/9.3.0
    ml OpenMPI/4.1.0rc1
    ml CUDA/11.0
    ml cuDNN/8.0.2.39-CUDA-11.0
    ml NCCL/2.8.3-1-CUDA-11.0
    ml PyTorch
    ml Horovod/0.20.3-Python-3.8.5
    #ml scikit
    source /p/project/covidnetx/environments/jusuf/bin/activate
fi
scripts/init_2022.sh
ADDED

@@ -0,0 +1,34 @@
machine=$(cat /etc/FZJ/systemname)
if [[ "$machine" == jurecadc ]]; then
    export CUDA_VISIBLE_DEVICES=0,1,2,3
    ml purge
    ml use $OTHERSTAGES
    ml Stages/2022
    ml GCC/11.2.0
    ml OpenMPI/4.1.2
    ml CUDA/11.5
    ml cuDNN/8.3.1.22-CUDA-11.5
    ml NCCL/2.12.7-1-CUDA-11.5
    ml PyTorch/1.11-CUDA-11.5
    ml Horovod/0.24
    ml torchvision/0.12.0
    source /p/project/covidnetx/environments/jureca_2022/bin/activate
fi
if [[ "$machine" == juwelsbooster ]]; then
    export CUDA_VISIBLE_DEVICES=0,1,2,3
    ml purge
    ml use $OTHERSTAGES
    ml Stages/2022
    ml GCC/11.2.0
    ml OpenMPI/4.1.2
    ml CUDA/11.5
    ml cuDNN/8.3.1.22-CUDA-11.5
    ml NCCL/2.12.7-1-CUDA-11.5
    ml PyTorch/1.11-CUDA-11.5
    ml Horovod/0.24
    ml torchvision/0.12.0
    source /p/project/covidnetx/environments/juwels_booster_2022/bin/activate
fi
if [[ "$machine" == jusuf ]]; then
    echo not supported
fi
scripts/run_jurecadc_ddp.sh
ADDED

@@ -0,0 +1,17 @@
#!/bin/bash -x
#SBATCH --account=zam
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=4
#SBATCH --cpus-per-task=24
#SBATCH --time=06:00:00
#SBATCH --gres=gpu:4
#SBATCH --partition=dc-gpu
source set_torch_distributed_vars.sh
#source scripts/init_2022.sh
#source scripts/init_2020.sh
source scripts/init.sh
export CUDA_VISIBLE_DEVICES=0,1,2,3
echo "Job id: $SLURM_JOB_ID"
export TOKENIZERS_PARALLELISM=false
export NCCL_ASYNC_ERROR_HANDLING=1
srun python -u $*
scripts/run_jusuf_ddp.sh
ADDED

@@ -0,0 +1,14 @@
#!/bin/bash -x
#SBATCH --account=zam
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=24
#SBATCH --time=06:00:00
#SBATCH --gres=gpu:1
#SBATCH --partition=gpus
source set_torch_distributed_vars.sh
source scripts/init.sh
export CUDA_VISIBLE_DEVICES=0,1,2,3
echo "Job id: $SLURM_JOB_ID"
export TOKENIZERS_PARALLELISM=false
srun python -u $*
scripts/run_juwelsbooster_ddp.sh
ADDED

@@ -0,0 +1,17 @@
#!/bin/bash -x
#SBATCH --account=covidnetx
#SBATCH --nodes=4
#SBATCH --ntasks-per-node=4
#SBATCH --cpus-per-task=24
#SBATCH --time=06:00:00
#SBATCH --gres=gpu:4
#SBATCH --partition=booster
source set_torch_distributed_vars.sh
#source scripts/init_2022.sh
#source scripts/init_2020.sh
source scripts/init.sh
export CUDA_VISIBLE_DEVICES=0,1,2,3
echo "Job id: $SLURM_JOB_ID"
export TOKENIZERS_PARALLELISM=false
export NCCL_ASYNC_ERROR_HANDLING=1
srun python -u $*
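A hypothetical way these wrappers are used: everything after the script name is forwarded to `srun python -u`, so a submission might look like the following (flags illustrative):

    sbatch scripts/run_jurecadc_ddp.sh train_ddgan.py --dataset wds --exp ddgan_cc12m_v11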
t5.py
ADDED

@@ -0,0 +1,99 @@
import torch
import transformers
from transformers import T5Tokenizer, T5EncoderModel, T5Config

transformers.logging.set_verbosity_error()

def exists(val):
    return val is not None

# config

MAX_LENGTH = 256

DEFAULT_T5_NAME = 'google/t5-v1_1-base'

T5_CONFIGS = {}

# singleton globals

def get_tokenizer(name):
    tokenizer = T5Tokenizer.from_pretrained(name)
    return tokenizer

def get_model(name):
    model = T5EncoderModel.from_pretrained(name)
    return model

def get_model_and_tokenizer(name):
    global T5_CONFIGS

    if name not in T5_CONFIGS:
        T5_CONFIGS[name] = dict()
    if "model" not in T5_CONFIGS[name]:
        T5_CONFIGS[name]["model"] = get_model(name)
    if "tokenizer" not in T5_CONFIGS[name]:
        T5_CONFIGS[name]["tokenizer"] = get_tokenizer(name)

    return T5_CONFIGS[name]['model'], T5_CONFIGS[name]['tokenizer']

def get_encoded_dim(name):
    if name not in T5_CONFIGS:
        # avoids loading the model if we only want to get the dim
        config = T5Config.from_pretrained(name)
        T5_CONFIGS[name] = dict(config=config)
    elif "config" in T5_CONFIGS[name]:
        config = T5_CONFIGS[name]["config"]
    elif "model" in T5_CONFIGS[name]:
        config = T5_CONFIGS[name]["model"].config
    else:
        assert False
    return config.d_model

class T5Encoder(torch.nn.Module):

    def __init__(self, name=DEFAULT_T5_NAME, max_length=MAX_LENGTH, padding='longest', masked_mean=False):
        super().__init__()
        self.name = name
        self.t5, self.tokenizer = get_model_and_tokenizer(name)
        self.max_length = max_length
        self.output_size = get_encoded_dim(name)
        self.padding = padding
        self.masked_mean = masked_mean

    def forward(self, x, return_only_pooled=True):
        encoded = self.tokenizer.batch_encode_plus(
            x,
            return_tensors = "pt",
            padding = self.padding,
            max_length = self.max_length,
            truncation = True
        )
        device = next(self.t5.parameters()).device
        input_ids = encoded.input_ids.to(device)
        attn_mask = encoded.attention_mask.to(device).bool()
        output = self.t5(input_ids = input_ids, attention_mask = attn_mask)
        encoded_text = output.last_hidden_state.detach()
        # return encoded_text[:, 0]
        # print(input_ids)
        # print(attn_mask)
        #if self.masked_mean:
        pooled = masked_mean(encoded_text, dim=1, mask=attn_mask)
        if return_only_pooled:
            return pooled
        else:
            return pooled, encoded_text, attn_mask
        #else:
        #    return encoded_text.mean(dim=1)


from einops import rearrange
def masked_mean(t, *, dim, mask = None):
    if not exists(mask):
        return t.mean(dim = dim)

    denom = mask.sum(dim = dim, keepdim = True)
    mask = rearrange(mask, 'b n -> b n 1')
    masked_t = t.masked_fill(~mask, 0.)

    return masked_t.sum(dim = dim) / denom.clamp(min = 1e-5)
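A short usage sketch of this encoder (weights download on first use; the prompt is illustrative):

    import torch
    from t5 import T5Encoder

    enc = T5Encoder(name="google/t5-v1_1-base")  # tokenizer + encoder, cached in T5_CONFIGS
    with torch.no_grad():
        pooled = enc(["a corgi wearing sunglasses"])  # (1, 768): masked mean over tokens
        pooled, tokens, mask = enc(["a corgi wearing sunglasses"], return_only_pooled=False)
    print(pooled.shape, tokens.shape, mask.shape)  # (1, 768), (1, T, 768), (1, T)

Since T5Encoder is an nn.Module holding the full T5 encoder, `.to(device)` moves the language model too, which is what test_ddgan.py relies on.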
test_ddgan.py
CHANGED

@@ -7,12 +7,12 @@
 import argparse
 import torch
 import numpy as np
-
+import time
 import os
-
+import json
 import torchvision
 from score_sde.models.ncsnpp_generator_adagn import NCSNpp
-
+import t5
 
 #%% Diffusion coefficients
 def var_func_vp(t, beta_min, beta_max):

@@ -112,7 +112,7 @@ def sample_posterior(coefficients, x_0,x_t, t):
 
     return sample_x_pos
 
-def sample_from_model(coefficients, generator, n_time, x_init, T, opt):
+def sample_from_model(coefficients, generator, n_time, x_init, T, opt, cond=None):
     x = x_init
     with torch.no_grad():
         for i in reversed(range(n_time)):

@@ -120,17 +120,70 @@ def sample_from_model(coefficients, generator, n_time, x_init, T, opt):
 
             t_time = t
             latent_z = torch.randn(x.size(0), opt.nz, device=x.device)#.to(x.device)
-            x_0 = generator(x, t_time, latent_z)
+            x_0 = generator(x, t_time, latent_z, cond=cond)
             x_new = sample_posterior(coefficients, x_0, x, t)
             x = x_new.detach()
 
     return x
 
+
+def sample_from_model_classifier_free_guidance(coefficients, generator, n_time, x_init, T, opt, text_encoder, cond=None, guidance_scale=0):
+    x = x_init
+    null = text_encoder([""] * len(x_init), return_only_pooled=False)
+    with torch.no_grad():
+        for i in reversed(range(n_time)):
+            t = torch.full((x.size(0),), i, dtype=torch.int64).to(x.device)
+
+            t_time = t
+            latent_z = torch.randn(x.size(0), opt.nz, device=x.device)
+
+            x_0_uncond = generator(x, t_time, latent_z, cond=null)
+            x_0_cond = generator(x, t_time, latent_z, cond=cond)
+
+            eps_uncond = (x - torch.sqrt(coefficients.alphas_cumprod[i]) * x_0_uncond) / torch.sqrt(1 - coefficients.alphas_cumprod[i])
+            eps_cond = (x - torch.sqrt(coefficients.alphas_cumprod[i]) * x_0_cond) / torch.sqrt(1 - coefficients.alphas_cumprod[i])
+
+            # eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)
+            eps = eps_uncond * (1 - guidance_scale) + eps_cond * guidance_scale
+            x_0 = (1/torch.sqrt(coefficients.alphas_cumprod[i])) * (x - torch.sqrt(1 - coefficients.alphas_cumprod[i]) * eps)
+
+            # Dynamic thresholding
+            q = args.dynamic_thresholding_percentile
+            print("Before", x_0.min(), x_0.max())
+            if q:
+                shape = x_0.shape
+                x_0_v = x_0.view(shape[0], -1)
+                d = torch.quantile(torch.abs(x_0_v), q, dim=1, keepdim=True)
+                d.clamp_(min=1)
+                x_0_v = x_0_v.clamp(-d, d) / d
+                x_0 = x_0_v.view(shape)
+            print("After", x_0.min(), x_0.max())
+
+            x_new = sample_posterior(coefficients, x_0, x, t)
+
+            # Dynamic thresholding
+            # q = args.dynamic_thresholding_percentile
+            # shape = x_new.shape
+            # x_new_v = x_new.view(shape[0], -1)
+            # d = torch.quantile(torch.abs(x_new_v), q, dim=1, keepdim=True)
+            # d = torch.maximum(d, torch.ones_like(d))
+            # d.clamp_(min = 1.)
+            # x_new_v = torch.clamp(x_new_v, -d, d) / d
+            # x_new = x_new_v.view(shape)
+            x = x_new.detach()
+
+    return x
+
+
 #%%
 def sample_and_test(args):
-    torch.manual_seed(
+    torch.manual_seed(args.seed)
     device = 'cuda:0'
-
+    text_encoder = t5.T5Encoder(name=args.text_encoder, masked_mean=args.masked_mean).to(device)
+    args.cond_size = text_encoder.output_size
+    # cond = text_encoder([str(yi%10) for yi in range(args.batch_size)])
+
     if args.dataset == 'cifar10':
         real_img_dir = 'pytorch_fid/cifar10_train_stat.npy'
     elif args.dataset == 'celeba_256':

@@ -157,7 +210,6 @@ def sample_and_test(args):
 
     pos_coeff = Posterior_Coefficients(args, device)
 
-    iters_needed = 50000 //args.batch_size
 
     save_dir = "./generated_samples/{}".format(args.dataset)
 

@@ -165,25 +217,90 @@ def sample_and_test(args):
         os.makedirs(save_dir)
 
     if args.compute_fid:
-        ...
+        from torch.nn.functional import adaptive_avg_pool2d
+        from pytorch_fid.fid_score import calculate_activation_statistics, calculate_fid_given_paths, ImagePathDataset, compute_statistics_of_path, calculate_frechet_distance
+        from pytorch_fid.inception import InceptionV3
+
+        texts = open(args.cond_text).readlines()
+        #iters_needed = len(texts) // args.batch_size
+        #texts = list(map(lambda s:s.strip(), texts))
+        #ntimes = max(30000 // len(texts), 1)
+        #texts = texts * ntimes
+        print("Text size:", len(texts))
+        #print("Iters:", iters_needed)
+        i = 0
+        dims = 2048
+        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
+        inceptionv3 = InceptionV3([block_idx]).to(device)
+
+        if not args.real_img_dir.endswith("npz"):
+            real_mu, real_sigma = compute_statistics_of_path(
+                args.real_img_dir, inceptionv3, args.batch_size, dims, device,
+                resize=args.image_size,
+            )
+            np.savez("inception_statistics.npz", mu=real_mu, sigma=real_sigma)
+        else:
+            stats = np.load(args.real_img_dir)
+            real_mu = stats['mu']
+            real_sigma = stats['sigma']
+
+        fake_features = []
+        for b in range(0, len(texts), args.batch_size):
+            text = texts[b:b+args.batch_size]
             with torch.no_grad():
-                ...
+                cond = text_encoder(text, return_only_pooled=False)
+                bs = len(text)
+                t0 = time.time()
+                x_t_1 = torch.randn(bs, args.num_channels,args.image_size, args.image_size).to(device)
+                if args.guidance_scale:
+                    fake_sample = sample_from_model_classifier_free_guidance(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, text_encoder, cond=cond, guidance_scale=args.guidance_scale)
+                else:
+                    fake_sample = sample_from_model(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, cond=cond)
                 fake_sample = to_range_0_1(fake_sample)
+                """
                 for j, x in enumerate(fake_sample):
                     index = i * args.batch_size + j
                     torchvision.utils.save_image(x, './generated_samples/{}/{}.jpg'.format(args.dataset, index))
-                ...
+                """
+                with torch.no_grad():
+                    pred = inceptionv3(fake_sample)[0]
+                # If model output is not scalar, apply global spatial average pooling.
+                # This happens if you choose a dimensionality not equal 2048.
+                if pred.size(2) != 1 or pred.size(3) != 1:
+                    pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
+                pred = pred.squeeze(3).squeeze(2).cpu().numpy()
+                fake_features.append(pred)
+                if i % 10 == 0:
+                    print('generating batch ', i, time.time() - t0)
+                """
+                if i % 10 == 0:
+                    ff = np.concatenate(fake_features)
+                    fake_mu = np.mean(ff, axis=0)
+                    fake_sigma = np.cov(ff, rowvar=False)
+                    fid = calculate_frechet_distance(real_mu, real_sigma, fake_mu, fake_sigma)
+                    print("FID", fid)
+                """
+                i += 1
+
+        fake_features = np.concatenate(fake_features)
+        fake_mu = np.mean(fake_features, axis=0)
+        fake_sigma = np.cov(fake_features, rowvar=False)
+        fid = calculate_frechet_distance(real_mu, real_sigma, fake_mu, fake_sigma)
+        dest = './saved_info/dd_gan/{}/{}/fid_{}.json'.format(args.dataset, args.exp, args.epoch_id)
+        results = {
+            "fid": fid,
+        }
+        results.update(vars(args))
+        with open(dest, "w") as fd:
+            json.dump(results, fd)
         print('FID = {}'.format(fid))
     else:
+        cond = text_encoder([args.cond_text] * args.batch_size, return_only_pooled=False)
         x_t_1 = torch.randn(args.batch_size, args.num_channels,args.image_size, args.image_size).to(device)
-        ...
+        if args.guidance_scale:
+            fake_sample = sample_from_model_classifier_free_guidance(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, text_encoder, cond=cond, guidance_scale=args.guidance_scale)
+        else:
+            fake_sample = sample_from_model(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, cond=cond)
         fake_sample = to_range_0_1(fake_sample)
         torchvision.utils.save_image(fake_sample, './samples_{}.jpg'.format(args.dataset))
 

@@ -198,6 +315,13 @@ if __name__ == '__main__':
     parser.add_argument('--compute_fid', action='store_true', default=False,
                             help='whether or not compute FID')
     parser.add_argument('--epoch_id', type=int,default=1000)
+    parser.add_argument('--guidance_scale', type=float,default=0)
+    parser.add_argument('--dynamic_thresholding_percentile', type=float,default=0)
+    parser.add_argument('--cond_text', type=str,default="0")
+
+    parser.add_argument('--cross_attention', action='store_true',default=False)
+
+
     parser.add_argument('--num_channels', type=int, default=3,
                             help='channel of image')
     parser.add_argument('--centered', action='store_false', default=True,

@@ -262,6 +386,8 @@ if __name__ == '__main__':
     parser.add_argument('--z_emb_dim', type=int, default=256)
     parser.add_argument('--t_emb_dim', type=int, default=256)
     parser.add_argument('--batch_size', type=int, default=200, help='sample generating batch size')
+    parser.add_argument('--text_encoder', type=str, default="google/t5-v1_1-base")
+    parser.add_argument('--masked_mean', action='store_true',default=False)
 
 

@@ -272,4 +398,4 @@ if __name__ == '__main__':
     sample_and_test(args)
 
 
-
+
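The guidance combination in sample_from_model_classifier_free_guidance is the standard classifier-free-guidance interpolation in an equivalent form: eps_uncond * (1 - w) + eps_cond * w equals eps_uncond + w * (eps_cond - eps_uncond), the version left in the comment. A tiny numeric check of the two forms, plus the dynamic-thresholding step as committed (values illustrative):

    import torch

    w = 2.0  # guidance_scale > 1 extrapolates beyond the conditional prediction
    eps_uncond, eps_cond = torch.randn(2, 4), torch.randn(2, 4)

    a = eps_uncond * (1 - w) + eps_cond * w        # form used in the commit
    b = eps_uncond + w * (eps_cond - eps_uncond)   # commented-out form
    assert torch.allclose(a, b, atol=1e-6)

    # Dynamic thresholding: clamp each sample's x_0 at its q-th absolute-value
    # quantile d (never below 1), then rescale so values lie in [-1, 1].
    x_0 = torch.randn(2, 3 * 64 * 64) * 3
    q = 0.995
    d = torch.quantile(x_0.abs(), q, dim=1, keepdim=True).clamp_(min=1)
    x_0 = x_0.clamp(-d, d) / d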
train_ddgan.py
CHANGED
@@ -18,7 +18,7 @@ import torch.optim as optim
|
|
18 |
import torchvision
|
19 |
|
20 |
import torchvision.transforms as transforms
|
21 |
-
from torchvision.datasets import CIFAR10
|
22 |
from datasets_prep.lsun import LSUN
|
23 |
from datasets_prep.stackmnist_data import StackedMNIST, _data_transforms_stacked_mnist
|
24 |
from datasets_prep.lmdb_datasets import LMDBDataset
|
@@ -27,6 +27,11 @@ from datasets_prep.lmdb_datasets import LMDBDataset
|
|
27 |
from torch.multiprocessing import Process
|
28 |
import torch.distributed as dist
|
29 |
import shutil
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
def copy_source(file, output_dir):
|
32 |
shutil.copyfile(file, os.path.join(output_dir, os.path.basename(file)))
|
@@ -172,7 +177,7 @@ def sample_posterior(coefficients, x_0,x_t, t):
|
|
172 |
|
173 |
return sample_x_pos
|
174 |
|
175 |
-
def sample_from_model(coefficients, generator, n_time, x_init, T, opt):
|
176 |
x = x_init
|
177 |
with torch.no_grad():
|
178 |
for i in reversed(range(n_time)):
|
@@ -180,13 +185,15 @@ def sample_from_model(coefficients, generator, n_time, x_init, T, opt):
|
|
180 |
|
181 |
t_time = t
|
182 |
latent_z = torch.randn(x.size(0), opt.nz, device=x.device)
|
183 |
-
x_0 = generator(x, t_time, latent_z)
|
184 |
x_new = sample_posterior(coefficients, x_0, x, t)
|
185 |
x = x_new.detach()
|
186 |
|
187 |
return x
|
188 |
|
189 |
-
|
|
|
|
|
190 |
def train(rank, gpu, args):
|
191 |
from score_sde.models.discriminator import Discriminator_small, Discriminator_large
|
192 |
from score_sde.models.ncsnpp_generator_adagn import NCSNpp
|
@@ -236,37 +243,81 @@ def train(rank, gpu, args):
|
|
236 |
transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
|
237 |
])
|
238 |
dataset = LMDBDataset(root='/datasets/celeba-lmdb/', name='celeba', train=True, transform=train_transform)
|
239 |
-
|
240 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
241 |
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
|
|
247 |
shuffle=False,
|
248 |
num_workers=4,
|
|
|
249 |
pin_memory=True,
|
250 |
-
sampler=train_sampler,
|
251 |
-
|
252 |
-
|
253 |
netG = NCSNpp(args).to(device)
|
|
|
|
|
|
|
|
|
254 |
|
255 |
|
256 |
if args.dataset == 'cifar10' or args.dataset == 'stackmnist':
|
257 |
netD = Discriminator_small(nc = 2*args.num_channels, ngf = args.ngf,
|
258 |
t_emb_dim = args.t_emb_dim,
|
|
|
259 |
act=nn.LeakyReLU(0.2)).to(device)
|
260 |
else:
|
261 |
netD = Discriminator_large(nc = 2*args.num_channels, ngf = args.ngf,
|
262 |
t_emb_dim = args.t_emb_dim,
|
|
|
263 |
act=nn.LeakyReLU(0.2)).to(device)
|
264 |
|
265 |
broadcast_params(netG.parameters())
|
266 |
broadcast_params(netD.parameters())
|
267 |
|
268 |
optimizerD = optim.Adam(netD.parameters(), lr=args.lr_d, betas = (args.beta1, args.beta2))
|
269 |
-
|
270 |
optimizerG = optim.Adam(netG.parameters(), lr=args.lr_g, betas = (args.beta1, args.beta2))
|
271 |
|
272 |
if args.use_ema:
|
@@ -297,9 +348,9 @@ def train(rank, gpu, args):
|
|
297 |
pos_coeff = Posterior_Coefficients(args, device)
|
298 |
T = get_time_schedule(args, device)
|
299 |
|
300 |
-
|
301 |
-
|
302 |
-
checkpoint = torch.load(checkpoint_file, map_location=
|
303 |
init_epoch = checkpoint['epoch']
|
304 |
epoch = init_epoch
|
305 |
netG.load_state_dict(checkpoint['netG_dict'])
|
@@ -319,9 +370,22 @@ def train(rank, gpu, args):
|
|
319 |
|
320 |
|
321 |
for epoch in range(init_epoch, args.num_epoch+1):
|
322 |
-
|
|
|
|
|
|
|
323 |
|
324 |
for iteration, (x, y) in enumerate(data_loader):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
325 |
for p in netD.parameters():
|
326 |
p.requires_grad = True
|
327 |
|
@@ -339,7 +403,7 @@ def train(rank, gpu, args):
|
|
339 |
|
340 |
|
341 |
# train with real
|
342 |
-
D_real = netD(x_t, t, x_tp1.detach()).view(-1)
|
343 |
|
344 |
errD_real = F.softplus(-D_real)
|
345 |
errD_real = errD_real.mean()
|
@@ -375,10 +439,10 @@ def train(rank, gpu, args):
|
|
375 |
latent_z = torch.randn(batch_size, nz, device=device)
|
376 |
|
377 |
|
378 |
-
x_0_predict = netG(x_tp1.detach(), t, latent_z)
|
379 |
x_pos_sample = sample_posterior(pos_coeff, x_0_predict, x_tp1, t)
|
380 |
|
381 |
-
output = netD(x_pos_sample, t, x_tp1.detach()).view(-1)
|
382 |
|
383 |
|
384 |
errD_fake = F.softplus(output)
|
@@ -407,11 +471,10 @@ def train(rank, gpu, args):
|
|
407 |
|
408 |
|
409 |
|
410 |
-
|
411 |
-
x_0_predict = netG(x_tp1.detach(), t, latent_z)
|
412 |
x_pos_sample = sample_posterior(pos_coeff, x_0_predict, x_tp1, t)
|
413 |
|
414 |
-
output = netD(x_pos_sample, t, x_tp1.detach()).view(-1)
|
415 |
|
416 |
|
417 |
errG = F.softplus(-output)
|
@@ -426,7 +489,27 @@ def train(rank, gpu, args):
|
|
426 |
if iteration % 100 == 0:
|
427 |
if rank == 0:
|
428 |
print('epoch {} iteration{}, G Loss: {}, D Loss: {}'.format(epoch,iteration, errG.item(), errD.item()))
|
429 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
430 |
if not args.no_lr_decay:
|
431 |
|
432 |
schedulerG.step()
|
@@ -437,7 +520,7 @@ def train(rank, gpu, args):
|
|
437 |
torchvision.utils.save_image(x_pos_sample, os.path.join(exp_path, 'xpos_epoch_{}.png'.format(epoch)), normalize=True)
|
438 |
|
439 |
x_t_1 = torch.randn_like(real_data)
|
440 |
-
fake_sample = sample_from_model(pos_coeff, netG, args.num_timesteps, x_t_1, T, args)
|
441 |
torchvision.utils.save_image(fake_sample, os.path.join(exp_path, 'sample_discrete_epoch_{}.png'.format(epoch)), normalize=True)
|
442 |
|
443 |
if args.save_content:
|
@@ -449,6 +532,7 @@ def train(rank, gpu, args):
|
|
449 |
'optimizerD': optimizerD.state_dict(), 'schedulerD': schedulerD.state_dict()}
|
450 |
|
451 |
torch.save(content, os.path.join(exp_path, 'content.pth'))
|
|
|
452 |
|
453 |
if epoch % args.save_ckpt_every == 0:
|
454 |
if args.use_ema:
|
@@ -462,11 +546,19 @@ def train(rank, gpu, args):
|
|
462 |
|
463 |
def init_processes(rank, size, fn, args):
|
464 |
""" Initialize the distributed environment. """
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
465 |
os.environ['MASTER_ADDR'] = args.master_address
|
466 |
-
os.environ['MASTER_PORT'] =
|
467 |
torch.cuda.set_device(args.local_rank)
|
468 |
gpu = args.local_rank
|
469 |
-
dist.init_process_group(backend='nccl', init_method='env://', rank=rank, world_size=
|
470 |
fn(rank, gpu, args)
|
471 |
dist.barrier()
|
472 |
cleanup()
|
@@ -480,7 +572,10 @@ if __name__ == '__main__':
                         help='seed used for initialization')
 
     parser.add_argument('--resume', action='store_true', default=False)
-
+    parser.add_argument('--masked_mean', action='store_true', default=False)
+    parser.add_argument('--text_encoder', type=str, default="google/t5-v1_1-base")
+    parser.add_argument('--cross_attention', action='store_true', default=False)
+
     parser.add_argument('--image_size', type=int, default=32,
                         help='size of image')
     parser.add_argument('--num_channels', type=int, default=3,
@@ -492,7 +587,7 @@ if __name__ == '__main__':
                         help='beta_min for diffusion')
     parser.add_argument('--beta_max', type=float, default=20.,
                         help='beta_max for diffusion')
-
+    parser.add_argument('--classifier_free_guidance_proba', type=float, default=0.0)
 
     parser.add_argument('--num_channels_dae', type=int, default=128,
                         help='number of initial channels in denoising model')
@@ -534,6 +629,7 @@ if __name__ == '__main__':
     # generator and training
     parser.add_argument('--exp', default='experiment_cifar_default', help='name of experiment')
     parser.add_argument('--dataset', default='cifar10', help='name of dataset')
+    parser.add_argument('--dataset_root', default='', help='root path of the dataset')
     parser.add_argument('--nz', type=int, default=100)
     parser.add_argument('--num_timesteps', type=int, default=4)
@@ -577,26 +673,28 @@ if __name__ == '__main__':
 
 
     args = parser.parse_args()
-    args.world_size = args.num_proc_node * args.num_process_per_node
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    # args.world_size = args.num_proc_node * args.num_process_per_node
+    args.world_size = int(os.getenv("SLURM_NTASKS"))
+    args.rank = int(os.environ['SLURM_PROCID'])
+    # size = args.num_process_per_node
+    init_processes(args.rank, args.world_size, train, args)
+    # if size > 1:
+    #     processes = []
+    #     for rank in range(size):
+    #         args.local_rank = rank
+    #         global_rank = rank + args.node_rank * args.num_process_per_node
+    #         global_size = args.num_proc_node * args.num_process_per_node
+    #         args.global_rank = global_rank
+    #         print('Node rank %d, local proc %d, global proc %d' % (args.node_rank, rank, global_rank))
+    #         p = Process(target=init_processes, args=(global_rank, global_size, train, args))
+    #         p.start()
+    #         processes.append(p)
+
+    #     for p in processes:
+    #         p.join()
+    # else:
+    #     print('starting in debug mode')
 
-    init_processes(0, size, train, args)
+    # init_processes(0, size, train, args)
 
-
+
 import torchvision
 
 import torchvision.transforms as transforms
+from torchvision.datasets import CIFAR10, ImageFolder
 from datasets_prep.lsun import LSUN
 from datasets_prep.stackmnist_data import StackedMNIST, _data_transforms_stacked_mnist
 from datasets_prep.lmdb_datasets import LMDBDataset
 
 from torch.multiprocessing import Process
 import torch.distributed as dist
 import shutil
+import logging
+import t5
+def log_and_continue(exn):
+    logging.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.')
+    return True
 
 def copy_source(file, output_dir):
     shutil.copyfile(file, os.path.join(output_dir, os.path.basename(file)))
 
     return sample_x_pos
 
+def sample_from_model(coefficients, generator, n_time, x_init, T, opt, cond=None):
     x = x_init
     with torch.no_grad():
         for i in reversed(range(n_time)):
             t = torch.full((x.size(0),), i, dtype=torch.int64).to(x.device)
 
             t_time = t
             latent_z = torch.randn(x.size(0), opt.nz, device=x.device)
+            x_0 = generator(x, t_time, latent_z, cond=cond)
             x_new = sample_posterior(coefficients, x_0, x, t)
             x = x_new.detach()
 
     return x
 
+
+from utils import ResampledShards2
+
 def train(rank, gpu, args):
     from score_sde.models.discriminator import Discriminator_small, Discriminator_large
     from score_sde.models.ncsnpp_generator_adagn import NCSNpp
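sample_from_model above threads the text condition unchanged through every denoising step: starting from pure noise, each of the n_time steps predicts x_0 with the conditioned generator and then samples the posterior one step back. A hedged usage sketch (an illustrative fragment, assuming netG, text_encoder, pos_coeff, T, args, and device are set up as in train() below; the caption is a placeholder):

x_T = torch.randn(4, args.num_channels, args.image_size, args.image_size, device=device)
with torch.no_grad():
    # text_encoder returns (pooled, sequence, mask) when return_only_pooled=False
    cond = text_encoder(["a photo of a dog"] * 4, return_only_pooled=False)
    samples = sample_from_model(pos_coeff, netG, args.num_timesteps, x_T, T, args, cond=cond)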
             transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
         ])
         dataset = LMDBDataset(root='/datasets/celeba-lmdb/', name='celeba', train=True, transform=train_transform)
+    elif args.dataset == "image_folder":
+        train_transform = transforms.Compose([
+            transforms.Resize(args.image_size),
+            transforms.CenterCrop(args.image_size),
+            # transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
+        ])
+        dataset = ImageFolder(root=args.dataset_root, transform=train_transform)
+    elif args.dataset == 'wds':
+        import webdataset as wds
+        train_transform = transforms.Compose([
+            transforms.Resize(args.image_size),
+            transforms.CenterCrop(args.image_size),
+            # transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
+        ])
+        # pipeline = [wds.SimpleShardList(args.dataset_root)]
+        pipeline = [ResampledShards2(args.dataset_root)]
+        pipeline.extend([
+            wds.split_by_node,
+            wds.split_by_worker,
+            wds.tarfile_to_samples(handler=log_and_continue),
+        ])
+        pipeline.extend([
+            wds.decode("pilrgb", handler=log_and_continue),
+            wds.rename(image="jpg;png"),
+            wds.map_dict(image=train_transform),
+            wds.to_tuple("image", "txt"),
+            wds.batched(batch_size, partial=False),
+        ])
+        dataset = wds.DataPipeline(*pipeline)
+        data_loader = wds.WebLoader(
+            dataset,
+            batch_size=None,
+            shuffle=False,
+            num_workers=8,
+        )
 
+    if args.dataset != "wds":
+        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset,
+                                                                        num_replicas=args.world_size,
+                                                                        rank=rank)
+        data_loader = torch.utils.data.DataLoader(dataset,
+                                                  batch_size=batch_size,
                                                   shuffle=False,
                                                   num_workers=4,
+                                                  drop_last=True,
                                                   pin_memory=True,
+                                                  sampler=train_sampler,)
+    text_encoder = t5.T5Encoder(name=args.text_encoder, masked_mean=args.masked_mean).to(device)
+    args.cond_size = text_encoder.output_size
     netG = NCSNpp(args).to(device)
+    nb_params = 0
+    for param in netG.parameters():
+        nb_params += param.flatten().shape[0]
+    print("Number of generator parameters:", nb_params)
 
 
     if args.dataset == 'cifar10' or args.dataset == 'stackmnist':
         netD = Discriminator_small(nc = 2*args.num_channels, ngf = args.ngf,
                                    t_emb_dim = args.t_emb_dim,
+                                   cond_size=text_encoder.output_size,
                                    act=nn.LeakyReLU(0.2)).to(device)
     else:
         netD = Discriminator_large(nc = 2*args.num_channels, ngf = args.ngf,
                                    t_emb_dim = args.t_emb_dim,
+                                   cond_size=text_encoder.output_size,
                                    act=nn.LeakyReLU(0.2)).to(device)
 
     broadcast_params(netG.parameters())
     broadcast_params(netD.parameters())
 
     optimizerD = optim.Adam(netD.parameters(), lr=args.lr_d, betas = (args.beta1, args.beta2))
     optimizerG = optim.Adam(netG.parameters(), lr=args.lr_g, betas = (args.beta1, args.beta2))
 
     if args.use_ema:
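Note how batching moves into the webdataset pipeline itself: wds.batched(batch_size, partial=False) emits ready-made batches, which is why the surrounding WebLoader is created with batch_size=None. A minimal self-contained sketch of the same pattern (the shard path is a placeholder, not from this repo):

import webdataset as wds
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
])
pipeline = wds.DataPipeline(
    wds.SimpleShardList("shards/{00000..00009}.tar"),  # placeholder shard pattern
    wds.tarfile_to_samples(),
    wds.decode("pilrgb"),
    wds.rename(image="jpg;png"),
    wds.map_dict(image=transform),
    wds.to_tuple("image", "txt"),
    wds.batched(4, partial=False),  # batching happens inside the pipeline
)
loader = wds.WebLoader(pipeline, batch_size=None, num_workers=2)  # hence batch_size=None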
utils.py ADDED
@@ -0,0 +1,67 @@
+from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, IterableDataset, get_worker_info
+from multiprocessing import Value
+import braceexpand
+import random
+import sys
+import webdataset as wds
+def pytorch_worker_seed():
+    """get dataloader worker seed from pytorch"""
+    worker_info = get_worker_info()
+    if worker_info is not None:
+        # favour the seed already created for pytorch dataloader workers if it exists
+        return worker_info.seed
+    # fallback to wds rank-based seed
+    return wds.utils.pytorch_worker_seed()
+
+
+class SharedEpoch:
+    def __init__(self, epoch: int = 0):
+        self.shared_epoch = Value('i', epoch)
+
+    def set_value(self, epoch):
+        self.shared_epoch.value = epoch
+
+    def get_value(self):
+        return self.shared_epoch.value
+
+
+class ResampledShards2(IterableDataset):
+    """An iterable dataset yielding a list of urls."""
+
+    def __init__(
+        self,
+        urls,
+        nshards=sys.maxsize,
+        worker_seed=None,
+        deterministic=False,
+        epoch=-1,
+    ):
+        """Sample shards from the shard list with replacement.
+
+        :param urls: a list of URLs as a Python list or brace notation string
+        """
+        super().__init__()
+        # urls = wds.shardlists.expand_urls(urls)
+        urls = list(braceexpand.braceexpand(urls))
+        self.urls = urls
+        assert isinstance(self.urls[0], str)
+        self.nshards = nshards
+        self.rng = random.Random()
+        self.worker_seed = pytorch_worker_seed if worker_seed is None else worker_seed
+        self.deterministic = deterministic
+        self.epoch = epoch
+
+    def __iter__(self):
+        """Return an iterator over the shards."""
+        if isinstance(self.epoch, SharedEpoch):
+            epoch = self.epoch.get_value()
+        else:
+            # NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
+            # situation as different workers may wrap at different times (or not at all).
+            self.epoch += 1
+            epoch = self.epoch
+        if self.deterministic:
+            # reset seed w/ epoch if deterministic; worker seed should be deterministic due to arg.seed
+            self.rng.seed(self.worker_seed() + epoch)
+        for _ in range(self.nshards):
+            yield dict(url=self.rng.choice(self.urls))
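ResampledShards2 samples shard URLs with replacement (effectively an infinite stream unless nshards is set) and yields dict(url=...) entries that wds.tarfile_to_samples can consume downstream; combined with the WDS_EPOCH bookkeeping in train_ddgan.py, deterministic=True makes the per-epoch shard order reproducible. A hedged usage sketch (the shard pattern is a placeholder):

# Illustrative only: feed resampled shards into a webdataset pipeline.
import webdataset as wds
from utils import ResampledShards2

shards = ResampledShards2("shards/{00000..00099}.tar", nshards=1000, deterministic=True)
pipeline = wds.DataPipeline(
    shards,                    # yields dict(url=...) entries, sampled with replacement
    wds.tarfile_to_samples(),  # opens each shard url and yields its samples
)
for i, sample in enumerate(pipeline):
    if i >= 3:
        break
    print(sorted(sample.keys()))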