Mehdi Cherti committed
Commit be61cf2 • 1 parent: ae26d48
update
- EMA.py +0 -1
- clip_encoder.py +64 -0
- encoder.py +9 -0
- run.py +103 -3
- scripts/init.sh +15 -0
- scripts/run_hdfml.sh +25 -0
- scripts/run_jurecadc_ddp.sh +4 -1
- test_ddgan.py +280 -64
- train_ddgan.py +158 -60
- utils.py +2 -1
EMA.py
CHANGED
@@ -39,7 +39,6 @@ class EMA(Optimizer):
                 # State initialization
                 if 'ema' not in state:
                     state['ema'] = p.data.clone()
-
                 if p.shape not in params:
                     params[p.shape] = {'idx': 0, 'data': []}
                     ema[p.shape] = []
clip_encoder.py
ADDED
@@ -0,0 +1,64 @@
+import torch
+import torch.nn as nn
+import open_clip
+from einops import rearrange
+
+
+def exists(val):
+    return val is not None
+
+class CLIPEncoder(nn.Module):
+
+    def __init__(self, model, pretrained):
+        super().__init__()
+        self.model = model
+        self.pretrained = pretrained
+        self.model, _, _ = open_clip.create_model_and_transforms(model, pretrained=pretrained)
+        self.output_size = self.model.transformer.width
+
+    def forward(self, texts, return_only_pooled=True):
+        device = next(self.parameters()).device
+        toks = open_clip.tokenize(texts).to(device)
+        x = self.model.token_embedding(toks)  # [batch_size, n_ctx, d_model]
+        x = x + self.model.positional_embedding
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        x = self.model.transformer(x, attn_mask=self.model.attn_mask)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+        x = self.model.ln_final(x)
+        mask = (toks != 0)
+        pooled = x[torch.arange(x.shape[0]), toks.argmax(dim=-1)] @ self.model.text_projection
+        if return_only_pooled:
+            return pooled
+        else:
+            return pooled, x, mask
+
+
+
+
+class CLIPImageEncoder(nn.Module):
+
+    def __init__(self, model_type="ViT-B/32"):
+        super().__init__()
+        import clip
+        self.model, preprocess = clip.load(model_type, device="cpu", jit=False)
+        CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
+        CLIP_STD = [0.26862954, 0.26130258, 0.27577711]
+        mean = torch.tensor(CLIP_MEAN).view(1, 3, 1, 1)
+        std = torch.tensor(CLIP_STD).view(1, 3, 1, 1)
+        self.register_buffer("mean", mean)
+        self.register_buffer("std", std)
+        self.output_size = 512
+
+    def forward_image(self, x):
+        x = torch.nn.functional.interpolate(x, mode='bicubic', size=(224, 224))
+        x = (x - self.mean) / self.std
+        return self.model.encode_image(x)
+
+    def forward_text(self, texts):
+        import clip
+        toks = clip.tokenize(texts, truncate=True).to(self.mean.device)
+        return self.model.encode_text(toks)
+
+
+
+
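Usage note: a minimal sketch of how the two new encoders can be called; the ViT-H-14 / laion2b_s32b_b79k pair is only an example tag, and the first call downloads the corresponding open_clip checkpoint.

import torch
from clip_encoder import CLIPEncoder, CLIPImageEncoder

# Text encoder: pooled embedding by default, per-token features and padding mask on request.
enc = CLIPEncoder("ViT-H-14", "laion2b_s32b_b79k").eval()
with torch.no_grad():
    pooled = enc(["a photo of a cat"])
    pooled, tokens, mask = enc(["a photo of a cat"], return_only_pooled=False)

# Image/text encoder used by the CLIP-score and CLIP-guidance code paths.
img_enc = CLIPImageEncoder()  # OpenAI ViT-B/32 by default
with torch.no_grad():
    img_feat = img_enc.forward_image(torch.rand(1, 3, 256, 256))  # resized to 224x224 internally
    txt_feat = img_enc.forward_text(["a photo of a cat"])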
encoder.py
ADDED
@@ -0,0 +1,9 @@
+import t5
+import clip_encoder
+
+def build_encoder(name, **kwargs):
+    if name.startswith("google"):
+        return t5.T5Encoder(name=name, **kwargs)
+    elif name.startswith("openclip"):
+        _, model, pretrained = name.split("/")
+        return clip_encoder.CLIPEncoder(model, pretrained)
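Usage note: build_encoder dispatches purely on the name prefix; a sketch (the keyword arguments accepted by t5.T5Encoder, such as masked_mean, are an assumption based on how the encoder is called elsewhere in this repo).

from encoder import build_encoder

# "openclip/<architecture>/<pretrained_tag>" -> clip_encoder.CLIPEncoder
text_encoder = build_encoder("openclip/ViT-H-14/laion2b_s32b_b79k")

# "google/..." -> t5.T5Encoder; extra kwargs are forwarded as-is
# text_encoder = build_encoder("google/t5-v1_1-xl", masked_mean=True)

print(text_encoder.output_size)  # conditioning size consumed by the generator/discriminator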
run.py
CHANGED
@@ -132,6 +132,8 @@ def ddgan_laion_aesthetic_v2():
 def ddgan_laion_aesthetic_v3():
     cfg = ddgan_laion_aesthetic_v1()
     cfg['model']['text_encoder'] = "google/t5-v1_1-xl"
+    cfg['model']['mismatch_loss'] = ''
+    cfg['model']['grad_penalty_cond'] = ''
     return cfg

 def ddgan_laion_aesthetic_v4():
@@ -146,6 +148,85 @@ def ddgan_laion_aesthetic_v5():
     cfg['model']['grad_penalty_cond'] = ''
     return cfg

+
+
+def ddgan_laion2b_v1():
+    cfg = ddgan_laion_aesthetic_v3()
+    cfg['model']['mismatch_loss'] = ''
+    cfg['model']['grad_penalty_cond'] = ''
+    cfg['model']['num_channels_dae'] = 224
+    cfg['model']['batch_size'] = 2
+    cfg['model']['discr_type'] = "large_cond_attn"
+    cfg['model']['preprocessing'] = 'random_resized_crop_v1'
+    return cfg
+
+def ddgan_laion_aesthetic_v6():
+    cfg = ddgan_laion_aesthetic_v3()
+    cfg['model']['no_lr_decay'] = ''
+    return cfg
+
+
+
+def ddgan_laion_aesthetic_v7():
+    cfg = ddgan_laion_aesthetic_v6()
+    cfg['model']['r1_gamma'] = 5
+    return cfg
+
+
+def ddgan_laion_aesthetic_v8():
+    cfg = ddgan_laion_aesthetic_v6()
+    cfg['model']['num_timesteps'] = 8
+    return cfg
+
+def ddgan_laion_aesthetic_v9():
+    cfg = ddgan_laion_aesthetic_v3()
+    cfg['model']['num_channels_dae'] = 384
+    return cfg
+
+def ddgan_sd_v1():
+    cfg = ddgan_laion_aesthetic_v3()
+    return cfg
+def ddgan_sd_v2():
+    cfg = ddgan_laion_aesthetic_v3()
+    return cfg
+def ddgan_sd_v3():
+    cfg = ddgan_laion_aesthetic_v3()
+    return cfg
+def ddgan_sd_v4():
+    cfg = ddgan_laion_aesthetic_v3()
+    return cfg
+def ddgan_sd_v5():
+    cfg = ddgan_laion_aesthetic_v3()
+    cfg['model']['num_timesteps'] = 8
+    return cfg
+def ddgan_sd_v6():
+    cfg = ddgan_laion_aesthetic_v3()
+    cfg['model']['num_channels_dae'] = 192
+    return cfg
+def ddgan_sd_v7():
+    cfg = ddgan_laion_aesthetic_v3()
+    return cfg
+def ddgan_sd_v8():
+    cfg = ddgan_laion_aesthetic_v3()
+    cfg['model']['image_size'] = 512
+    return cfg
+def ddgan_laion_aesthetic_v12():
+    cfg = ddgan_laion_aesthetic_v3()
+    return cfg
+def ddgan_laion_aesthetic_v13():
+    cfg = ddgan_laion_aesthetic_v3()
+    cfg['model']['text_encoder'] = "openclip/ViT-H-14/laion2b_s32b_b79k"
+    return cfg
+
+def ddgan_laion_aesthetic_v14():
+    cfg = ddgan_laion_aesthetic_v3()
+    cfg['model']['text_encoder'] = "openclip/ViT-H-14/laion2b_s32b_b79k"
+    return cfg
+def ddgan_sd_v9():
+    cfg = ddgan_laion_aesthetic_v3()
+    cfg['model']['text_encoder'] = "openclip/ViT-H-14/laion2b_s32b_b79k"
+    return cfg
+
 models = [
     ddgan_cifar10_cond17, # cifar10, cross attn for discr
     ddgan_cifar10_cond18, # cifar10, xl encoder
@@ -166,6 +247,23 @@ models = [
     ddgan_laion_aesthetic_v3, # like ddgan_laion_aesthetic_v1 but trained from scratch with T5-XL (continue from 23aug with mismatch and grad penalty and random_resized_crop_v1)
     ddgan_laion_aesthetic_v4, # like ddgan_laion_aesthetic_v1 but trained from scratch with OpenAI's ClipEncoder
     ddgan_laion_aesthetic_v5, # fine-tune ddgan_laion_aesthetic_v1 with mismatch and cond grad penalty losses
+    ddgan_laion_aesthetic_v6, # like v3 but without lr decay
+    ddgan_laion_aesthetic_v7, # like v6 but with r1 gamma of 5 instead of 1, trying to constrain the discr more.
+    ddgan_laion_aesthetic_v8, # like v6 but with 8 timesteps
+    ddgan_laion_aesthetic_v9,
+    ddgan_laion_aesthetic_v12,
+    ddgan_laion_aesthetic_v13,
+    ddgan_laion_aesthetic_v14,
+    ddgan_laion2b_v1,
+    ddgan_sd_v1,
+    ddgan_sd_v2,
+    ddgan_sd_v3,
+    ddgan_sd_v4,
+    ddgan_sd_v5,
+    ddgan_sd_v6,
+    ddgan_sd_v7,
+    ddgan_sd_v8,
+    ddgan_sd_v9,
 ]

 def get_model(model_name):
@@ -174,7 +272,7 @@ def get_model(model_name):
     return model()


-def test(model_name, *, cond_text="", batch_size:int=None, epoch:int=None, guidance_scale:float=0, fid=False, real_img_dir="", q=0.0, seed=0, nb_images_for_fid=0, scale_factor_h=1, scale_factor_w=1, compute_clip_score=False):
+def test(model_name, *, cond_text="", batch_size:int=None, epoch:int=None, guidance_scale:float=0, fid=False, real_img_dir="", q=0.0, seed=0, nb_images_for_fid=0, scale_factor_h=1, scale_factor_w=1, compute_clip_score=False, eval_name="", scale_method="convolutional"):

     cfg = get_model(model_name)
     model = cfg['model']
@@ -204,13 +302,15 @@ def test(model_name, *, cond_text="", batch_size:int=None, epoch:int=None, guida
     args['scale_factor_h'] = scale_factor_h
     args['scale_factor_w'] = scale_factor_w
     args['n_mlp'] = model.get("n_mlp")
+    args['scale_method'] = scale_method
     if fid:
         args['compute_fid'] = ''
         args['real_img_dir'] = real_img_dir
         args['nb_images_for_fid'] = nb_images_for_fid
     if compute_clip_score:
         args['compute_clip_score'] = ""
-
+    if eval_name:
+        args["eval_name"] = eval_name
     cmd = "python -u test_ddgan.py " + " ".join(f"--{k} {v}" for k, v in args.items() if v is not None)
     print(cmd)
     call(cmd, shell=True)
@@ -234,4 +334,4 @@ def eval_results(model_name):

 if __name__ == "__main__":
     from clize import run
-    run([test, eval_results])
+    run([test, eval_results])
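For reference, a sketch of the flag-building convention test() relies on: config values set to the empty string turn into bare flags for test_ddgan.py's store_true arguments, while None values are dropped entirely (the values below are illustrative only).

args = {
    "dataset": "wds",
    "compute_fid": "",          # empty string -> bare flag for an action='store_true' argument
    "nb_images_for_fid": 5000,
    "eval_name": None,          # None -> omitted from the command entirely
}
cmd = "python -u test_ddgan.py " + " ".join(f"--{k} {v}" for k, v in args.items() if v is not None)
print(cmd)
# python -u test_ddgan.py --dataset wds --compute_fid  --nb_images_for_fid 5000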
scripts/init.sh
CHANGED
@@ -32,6 +32,21 @@ if [[ "$machine" == juwelsbooster ]]; then
     ml torchvision/0.12.0
     source /p/project/covidnetx/environments/juwels_booster_2022/bin/activate
 fi
+if [[ "$machine" == hdfml ]]; then
+    export CUDA_VISIBLE_DEVICES=0,1,2,3
+    ml purge
+    ml use $OTHERSTAGES
+    ml Stages/2022
+    ml GCC/11.2.0
+    ml OpenMPI/4.1.2
+    ml CUDA/11.5
+    ml cuDNN/8.3.1.22-CUDA-11.5
+    ml NCCL/2.12.7-1-CUDA-11.5
+    ml PyTorch/1.11-CUDA-11.5
+    ml Horovod/0.24
+    ml torchvision/0.12.0
+    source envs/hdfml/bin/activate
+fi
 if [[ "$machine" == jusuf ]]; then
     echo not supported
 fi
scripts/run_hdfml.sh
ADDED
@@ -0,0 +1,25 @@
+#!/bin/bash -x
+#SBATCH --account=cstdl
+#SBATCH --nodes=8
+#SBATCH --ntasks-per-node=4
+#SBATCH --cpus-per-task=8
+#SBATCH --time=06:00:00
+#SBATCH --gres=gpu
+#SBATCH --partition=batch
+ml purge
+ml use $OTHERSTAGES
+ml Stages/2022
+ml GCC/11.2.0
+ml OpenMPI/4.1.2
+ml CUDA/11.5
+ml cuDNN/8.3.1.22-CUDA-11.5
+ml NCCL/2.12.7-1-CUDA-11.5
+ml PyTorch/1.11-CUDA-11.5
+ml Horovod/0.24
+ml torchvision/0.12.0
+source envs/hdfml/bin/activate
+export CUDA_VISIBLE_DEVICES=0,1,2,3
+echo "Job id: $SLURM_JOB_ID"
+export TOKENIZERS_PARALLELISM=false
+export NCCL_ASYNC_ERROR_HANDLING=1
+srun python -u $*
scripts/run_jurecadc_ddp.sh
CHANGED
@@ -13,5 +13,8 @@ source scripts/init.sh
 export CUDA_VISIBLE_DEVICES=0,1,2,3
 echo "Job id: $SLURM_JOB_ID"
 export TOKENIZERS_PARALLELISM=false
-export NCCL_ASYNC_ERROR_HANDLING=1
+#export NCCL_ASYNC_ERROR_HANDLING=1
+export NCCL_IB_TIMEOUT=50
+export UCX_RC_TIMEOUT=4s
+export NCCL_IB_RETRY_CNT=10
 srun python -u $*
test_ddgan.py
CHANGED
@@ -86,7 +86,18 @@ class Posterior_Coefficients():
         self.posterior_mean_coef2 = ((1 - self.alphas_cumprod_prev) * torch.sqrt(self.alphas) / (1 - self.alphas_cumprod))

         self.posterior_log_variance_clipped = torch.log(self.posterior_variance.clamp(min=1e-20))
-
+
+def predict_q_posterior(coefficients, x_0, x_t, t):
+    mean = (
+        extract(coefficients.posterior_mean_coef1, t, x_t.shape) * x_0
+        + extract(coefficients.posterior_mean_coef2, t, x_t.shape) * x_t
+    )
+    var = extract(coefficients.posterior_variance, t, x_t.shape)
+    log_var_clipped = extract(coefficients.posterior_log_variance_clipped, t, x_t.shape)
+    return mean, var, log_var_clipped
+
+
+
 def sample_posterior(coefficients, x_0,x_t, t):

     def q_posterior(x_0, x_t, t):
@@ -150,10 +161,10 @@ def sample_from_model_classifier_free_guidance(coefficients, generator, n_time,
             # eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond)
             eps = eps_uncond * (1 - guidance_scale) + eps_cond * guidance_scale
             x_0 = (1/torch.sqrt(coefficients.alphas_cumprod[i])) * (x - torch.sqrt(1 - coefficients.alphas_cumprod[i]) * eps)
-
+            #x_0 = x_0_uncond * (1 - guidance_scale) + x_0_cond * guidance_scale

             # Dynamic thresholding
-            q =
+            q = opt.dynamic_thresholding_quantile
             #print("Before", x_0.min(), x_0.max())
             if q:
                 shape = x_0.shape
@@ -180,9 +191,174 @@
     return x


+def sample_from_model_classifier_free_guidance_convolutional(coefficients, generator, n_time, x_init, T, opt, text_encoder, cond=None, guidance_scale=0, split_input_params=None):
+    x = x_init
+    null = text_encoder([""] * len(x_init), return_only_pooled=False)
+    #latent_z = torch.randn(x.size(0), opt.nz, device=x.device)
+    ks = split_input_params["ks"]  # eg. (128, 128)
+    stride = split_input_params["stride"]  # eg. (64, 64)
+    uf = split_input_params["vqf"]
+    with torch.no_grad():
+        for i in reversed(range(n_time)):
+            t = torch.full((x.size(0),), i, dtype=torch.int64).to(x.device)
+            t_time = t
+            latent_z = torch.randn(x.size(0), opt.nz, device=x.device)
+
+            fold, unfold, normalization, weighting = get_fold_unfold(x, ks, stride, split_input_params, uf=uf)
+            x = unfold(x)
+            x = x.view((x.shape[0], -1, ks[0], ks[1], x.shape[-1]))
+            x_new_list = []
+            for j in range(x.shape[-1]):
+                x_0_uncond = generator(x[:,:,:,:,j], t_time, latent_z, cond=null)
+                x_0_cond = generator(x[:,:,:,:,j], t_time, latent_z, cond=cond)
+
+                eps_uncond = (x[:,:,:,:,j] - torch.sqrt(coefficients.alphas_cumprod[i]) * x_0_uncond) / torch.sqrt(1 - coefficients.alphas_cumprod[i])
+                eps_cond = (x[:,:,:,:,j] - torch.sqrt(coefficients.alphas_cumprod[i]) * x_0_cond) / torch.sqrt(1 - coefficients.alphas_cumprod[i])
+
+                eps = eps_uncond * (1 - guidance_scale) + eps_cond * guidance_scale
+                x_0 = (1/torch.sqrt(coefficients.alphas_cumprod[i])) * (x[:,:,:,:,j] - torch.sqrt(1 - coefficients.alphas_cumprod[i]) * eps)
+                q = args.dynamic_thresholding_quantile
+                if q:
+                    shape = x_0.shape
+                    x_0_v = x_0.view(shape[0], -1)
+                    d = torch.quantile(torch.abs(x_0_v), q, dim=1, keepdim=True)
+                    d.clamp_(min=1)
+                    x_0_v = x_0_v.clamp(-d, d) / d
+                    x_0 = x_0_v.view(shape)
+                x_new = sample_posterior(coefficients, x_0, x[:,:,:,:,j], t)
+                x_new_list.append(x_new)
+
+            o = torch.stack(x_new_list, axis=-1)
+            #o = o * weighting
+            o = o.view((o.shape[0], -1, o.shape[-1]))
+            decoded = fold(o)
+            decoded = decoded / normalization
+            x = decoded.detach()
+
+    return x
+
+def sample_from_model_clip_guidance(coefficients, generator, clip_model, n_time, x_init, T, opt, texts, cond=None, guidance_scale=0):
+    x = x_init
+    text_features = torch.nn.functional.normalize(clip_model.forward_text(texts), dim=1)
+    n_time = 16
+    for i in reversed(range(n_time)):
+        t = torch.full((x.size(0),), i%4, dtype=torch.int64).to(x.device)
+        t_time = t
+        latent_z = torch.randn(x.size(0), opt.nz, device=x.device)
+        x.requires_grad = True
+        x_0 = generator(x, t_time, latent_z, cond=cond)
+        x_new = sample_posterior(coefficients, x_0, x, t)
+        x_new_n = (x_new + 1) / 2
+        image_features = torch.nn.functional.normalize(clip_model.forward_image(x_new_n), dim=1)
+        loss = (image_features*text_features).sum(dim=1).mean()
+        x_grad, = torch.autograd.grad(loss, x)
+        lr = 3000
+        x = x.detach()
+        print(x.min(),x.max(), lr*x_grad.min(), lr*x_grad.max())
+        x += x_grad * lr
+
+        with torch.no_grad():
+            x_0 = generator(x, t_time, latent_z, cond=cond)
+            x_new = sample_posterior(coefficients, x_0, x, t)
+
+        x = x_new.detach()
+        print(i)
+    return x
+
+def meshgrid(h, w):
+    y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
+    x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
+
+    arr = torch.cat([y, x], dim=-1)
+    return arr
+def delta_border(h, w):
+    """
+    :param h: height
+    :param w: width
+    :return: normalized distance to image border,
+     wtith min distance = 0 at border and max dist = 0.5 at image center
+    """
+    lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
+    arr = meshgrid(h, w) / lower_right_corner
+    dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
+    dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
+    edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
+    return edge_dist
+
+def get_weighting(h, w, Ly, Lx, device, split_input_params):
+    weighting = delta_border(h, w)
+    weighting = torch.clip(weighting, split_input_params["clip_min_weight"],
+                           split_input_params["clip_max_weight"], )
+    weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
+
+    if split_input_params["tie_braker"]:
+        L_weighting = delta_border(Ly, Lx)
+        L_weighting = torch.clip(L_weighting,
+                                 split_input_params["clip_min_tie_weight"],
+                                 split_input_params["clip_max_tie_weight"])
+
+        L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
+        weighting = weighting * L_weighting
+    return weighting
+
+def get_fold_unfold(x, kernel_size, stride, split_input_params, uf=1, df=1):  # todo load once not every time, shorten code
+    """
+    :param x: img of size (bs, c, h, w)
+    :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
+    """
+    bs, nc, h, w = x.shape
+
+    # number of crops in image
+    Ly = (h - kernel_size[0]) // stride[0] + 1
+    Lx = (w - kernel_size[1]) // stride[1] + 1
+
+    if uf == 1 and df == 1:
+        fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+        unfold = torch.nn.Unfold(**fold_params)
+
+        fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
+
+        weighting = get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device, split_input_params).to(x.dtype)
+        normalization = fold(weighting).view(1, 1, h, w)  # normalizes the overlap
+        weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
+
+    elif uf > 1 and df == 1:
+        fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+        unfold = torch.nn.Unfold(**fold_params)
+
+        fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
+                            dilation=1, padding=0,
+                            stride=(stride[0] * uf, stride[1] * uf))
+        fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
+
+        weighting = get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device, split_input_params).to(x.dtype)
+        normalization = fold(weighting).view(1, 1, h * uf, w * uf)  # normalizes the overlap
+        weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
+
+    elif df > 1 and uf == 1:
+        fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+        unfold = torch.nn.Unfold(**fold_params)
+
+        fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
+                            dilation=1, padding=0,
+                            stride=(stride[0] // df, stride[1] // df))
+        fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
+
+        weighting = get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device, split_input_params).to(x.dtype)
+        normalization = fold(weighting).view(1, 1, h // df, w // df)  # normalizes the overlap
+        weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
+
+    else:
+        raise NotImplementedError
+
+    return fold, unfold, normalization, weighting
+
+
 #%%
 def sample_and_test(args):
     torch.manual_seed(args.seed)
+
     device = 'cuda:0'
     text_encoder =build_encoder(name=args.text_encoder, masked_mean=args.masked_mean).to(device)
     args.cond_size = text_encoder.output_size
@@ -197,10 +373,9 @@ def sample_and_test(args):

     to_range_0_1 = lambda x: (x + 1.) / 2.

-
+    print(vars(args))
     netG = NCSNpp(args).to(device)
-
-
+
     if args.epoch_id == -1:
         epochs = range(1000)
     else:
@@ -209,17 +384,27 @@ def sample_and_test(args):
     for epoch in epochs:
         args.epoch_id = epoch
         path = './saved_info/dd_gan/{}/{}/netG_{}.pth'.format(args.dataset, args.exp, args.epoch_id)
+        next_path = './saved_info/dd_gan/{}/{}/netG_{}.pth'.format(args.dataset, args.exp, args.epoch_id+1)
         if not os.path.exists(path):
             continue
+        print(path)
+
+        #if not os.path.exists(next_path):
+        #    print(f"STOP at {epoch}")
+        #    break
         ckpt = torch.load(path, map_location=device)
-
+        suffix = '_' + args.eval_name if args.eval_name else ""
+        dest = './saved_info/dd_gan/{}/{}/eval_{}{}.json'.format(args.dataset, args.exp, args.epoch_id, suffix)
+        next_dest = './saved_info/dd_gan/{}/{}/eval_{}{}.json'.format(args.dataset, args.exp, args.epoch_id+1, suffix)

-        if args.compute_fid and
+        if (args.compute_fid or args.compute_clip_score) and os.path.exists(dest):
             continue
         print("Eval Epoch", args.epoch_id)
         #loading weights from ddp in single gpu
+        #print(ckpt.keys())
         for key in list(ckpt.keys()):
-
+            if key.startswith("module"):
+                ckpt[key[7:]] = ckpt.pop(key)
         netG.load_state_dict(ckpt)
         netG.eval()

@@ -234,7 +419,7 @@ def sample_and_test(args):
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

-        if args.compute_fid:
+        if args.compute_fid or args.compute_clip_score:
            from torch.nn.functional import adaptive_avg_pool2d
            from pytorch_fid.fid_score import calculate_activation_statistics, calculate_fid_given_paths, ImagePathDataset, compute_statistics_of_path, calculate_frechet_distance
            from pytorch_fid.inception import InceptionV3
@@ -252,9 +437,11 @@ def sample_and_test(args):
            print("Text size:", len(texts))
            #print("Iters:", iters_needed)
            i = 0
-
-
-
+
+            if args.compute_fid:
+                dims = 2048
+                block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
+                inceptionv3 = InceptionV3([block_idx]).to(device)

            if args.compute_clip_score:
                import clip
@@ -264,19 +451,20 @@ def sample_and_test(args):
                clip_mean = torch.Tensor(CLIP_MEAN).view(1,-1,1,1).to(device)
                clip_std = torch.Tensor(CLIP_STD).view(1,-1,1,1).to(device)

-
-
-
-
-
-
-
-
-
-
-
-
-
+            if args.compute_fid:
+                if not args.real_img_dir.endswith("npz"):
+                    real_mu, real_sigma = compute_statistics_of_path(
+                        args.real_img_dir, inceptionv3, args.batch_size, dims, device,
+                        resize=args.image_size,
+                    )
+                    np.savez("inception_statistics.npz", mu=real_mu, sigma=real_sigma)
+                else:
+                    stats = np.load(args.real_img_dir)
+                    real_mu = stats['mu']
+                    real_sigma = stats['sigma']
+
+                fake_features = []
+
            if args.compute_clip_score:
                clip_scores = []

@@ -287,7 +475,6 @@ def sample_and_test(args):
                bs = len(text)
                t0 = time.time()
                x_t_1 = torch.randn(bs, args.num_channels,args.image_size, args.image_size).to(device)
-                #print(x_t_1.shape)
                if args.guidance_scale:
                    fake_sample = sample_from_model_classifier_free_guidance(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, text_encoder, cond=cond, guidance_scale=args.guidance_scale)
                else:
@@ -298,45 +485,39 @@ def sample_and_test(args):
                    index = i * args.batch_size + j
                    torchvision.utils.save_image(x, './generated_samples/{}/{}.jpg'.format(args.dataset, index))
                """
-
-
-
-
-
-
-
-
+
+                if args.compute_fid:
+                    with torch.no_grad():
+                        pred = inceptionv3(fake_sample)[0]
+                    # If model output is not scalar, apply global spatial average pooling.
+                    # This happens if you choose a dimensionality not equal 2048.
+                    if pred.size(2) != 1 or pred.size(3) != 1:
+                        pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
+                    pred = pred.squeeze(3).squeeze(2).cpu().numpy()
+                    fake_features.append(pred)

                if args.compute_clip_score:
                    with torch.no_grad():
                        clip_ims = torch.nn.functional.interpolate(fake_sample, (224, 224), mode="bicubic")
-
+                        clip_ims = (clip_ims - clip_mean) / clip_std
+                        clip_txt = clip.tokenize(text, truncate=True).to(device)
                        imf = clip_model.encode_image(clip_ims)
                        txtf = clip_model.encode_text(clip_txt)
                        imf = torch.nn.functional.normalize(imf, dim=1)
                        txtf = torch.nn.functional.normalize(txtf, dim=1)
                        clip_scores.append(((imf * txtf).sum(dim=1)).cpu())
-
+
                if i % 10 == 0:
-                    print('
-                    """
-                    if i % 10 == 0:
-                        ff = np.concatenate(fake_features)
-                        fake_mu = np.mean(ff, axis=0)
-                        fake_sigma = np.cov(ff, rowvar=False)
-                        fid = calculate_frechet_distance(real_mu, real_sigma, fake_mu, fake_sigma)
-                        print("FID", fid)
-                    """
+                    print('evaluating batch ', i, time.time() - t0)
                i += 1

-
-
-
-
-
-
-
-            }
+            results = {}
+            if args.compute_fid:
+                fake_features = np.concatenate(fake_features)
+                fake_mu = np.mean(fake_features, axis=0)
+                fake_sigma = np.cov(fake_features, rowvar=False)
+                fid = calculate_frechet_distance(real_mu, real_sigma, fake_mu, fake_sigma)
+                results['fid'] = fid
            if args.compute_clip_score:
                clip_score = torch.cat(clip_scores).mean().item()
                results['clip_score'] = clip_score
@@ -344,22 +525,54 @@ def sample_and_test(args):
            with open(dest, "w") as fd:
                json.dump(results, fd)
            print(results)
-        else:
+        else:
            if args.cond_text.endswith(".txt"):
                texts = open(args.cond_text).readlines()
                texts = [t.strip() for t in texts]
            else:
                texts = [args.cond_text] * args.batch_size
-
-
-
-
-
+            clip_guidance = False
+            if clip_guidance:
+                from clip_encoder import CLIPImageEncoder
+                cond = text_encoder(texts, return_only_pooled=False)
+                clip_image_model = CLIPImageEncoder().to(device)
+                x_t_1 = torch.randn(len(texts), args.num_channels,args.image_size*args.scale_factor_h, args.image_size*args.scale_factor_w).to(device)
+                fake_sample = sample_from_model_clip_guidance(pos_coeff, netG, clip_image_model, args.num_timesteps, x_t_1,T, args, texts, cond=cond, guidance_scale=args.guidance_scale)
+                fake_sample = to_range_0_1(fake_sample)
+                torchvision.utils.save_image(fake_sample, './samples_{}.jpg'.format(args.dataset))
+
            else:
-
-
-
-
+                cond = text_encoder(texts, return_only_pooled=False)
+                x_t_1 = torch.randn(len(texts), args.num_channels,args.image_size*args.scale_factor_h, args.image_size*args.scale_factor_w).to(device)
+                t0 = time.time()
+                if args.guidance_scale:
+                    if args.scale_factor_h > 1 or args.scale_factor_w > 1:
+                        if args.scale_method == "convolutional":
+                            split_input_params = {
+                                "ks": (args.image_size, args.image_size),
+                                "stride": (150, 150),
+                                "clip_max_tie_weight": 0.5,
+                                "clip_min_tie_weight": 0.01,
+                                "clip_max_weight": 0.5,
+                                "clip_min_weight": 0.01,
+
+                                "tie_braker": True,
+                                'vqf': 1,
+                            }
+                            fake_sample = sample_from_model_classifier_free_guidance_convolutional(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, text_encoder, cond=cond, guidance_scale=args.guidance_scale, split_input_params=split_input_params)
+                        elif args.scale_method == "larger_input":
+                            netG.attn_resolutions = [r * args.scale_factor_w for r in netG.attn_resolutions]
+                            fake_sample = sample_from_model_classifier_free_guidance(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, text_encoder, cond=cond, guidance_scale=args.guidance_scale)
+                    else:
+                        fake_sample = sample_from_model_classifier_free_guidance(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, text_encoder, cond=cond, guidance_scale=args.guidance_scale)
+                else:
+                    fake_sample = sample_from_model(pos_coeff, netG, args.num_timesteps, x_t_1,T, args, cond=cond)
+
+                print(time.time() - t0)
+                fake_sample = to_range_0_1(fake_sample)
+                torchvision.utils.save_image(fake_sample, './samples_{}.jpg'.format(args.dataset))
+
+



@@ -374,6 +587,7 @@ if __name__ == '__main__':
    parser.add_argument('--compute_clip_score', action='store_true', default=False,
                        help='whether or not compute CLIP score')
    parser.add_argument('--clip_model', type=str,default="ViT-L/14")
+    parser.add_argument('--eval_name', type=str,default="")

    parser.add_argument('--epoch_id', type=int,default=1000)
    parser.add_argument('--guidance_scale', type=float,default=0)
@@ -381,6 +595,8 @@ if __name__ == '__main__':
    parser.add_argument('--cond_text', type=str,default="0")
    parser.add_argument('--scale_factor_h', type=int,default=1)
    parser.add_argument('--scale_factor_w', type=int,default=1)
+    parser.add_argument('--scale_method', type=str,default="convolutional")
+
    parser.add_argument('--cross_attention', action='store_true',default=False)

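The "convolutional" scale_method above rests on the unfold/fold round trip from get_fold_unfold: the input is cut into overlapping crops, each crop is denoised on its own, and the crops are folded back with a normalization map so overlapping regions are averaged. A self-contained sketch of that round trip, using uniform overlap counts instead of the border weighting computed by get_weighting:

import torch

x = torch.randn(1, 3, 256, 256)
ks, stride = (128, 128), (64, 64)

unfold = torch.nn.Unfold(kernel_size=ks, stride=stride)
fold = torch.nn.Fold(output_size=x.shape[2:], kernel_size=ks, stride=stride)

patches = unfold(x)                                      # (1, 3*128*128, L) with L overlapping crops
patches = patches.view(1, 3, ks[0], ks[1], -1)
processed = patches * 1.0                                # stand-in for the per-crop denoising step
out = fold(processed.view(1, -1, processed.shape[-1]))   # fold sums the overlapping crops
counts = fold(unfold(torch.ones_like(x)))                # how many crops cover each pixel
out = out / counts                                       # average the overlaps
print(torch.allclose(out, x, atol=1e-5))                 # True: identity round trip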
train_ddgan.py
CHANGED
@@ -5,7 +5,7 @@
 # for Denoising Diffusion GAN. To view a copy of this license, see the LICENSE file.
 # ---------------------------------------------------------------

-
+from glob import glob
 import argparse
 import torch
 import numpy as np
@@ -30,6 +30,7 @@ import shutil
 import logging
 from encoder import build_encoder
 from utils import ResampledShards2
+from torch.utils.tensorboard import SummaryWriter


 def log_and_continue(exn):
@@ -194,23 +195,29 @@ def sample_from_model(coefficients, generator, n_time, x_init, T, opt, cond=None

     return x

-
+from contextlib import suppress

 def filter_no_caption(sample):
     return 'txt' in sample

-
+def get_autocast(precision):
+    if precision == 'amp':
+        return torch.cuda.amp.autocast
+    elif precision == 'amp_bfloat16':
+        return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16)
+    else:
+        return suppress

 def train(rank, gpu, args):
     from score_sde.models.discriminator import Discriminator_small, Discriminator_large, CondAttnDiscriminator, SmallCondAttnDiscriminator
     from score_sde.models.ncsnpp_generator_adagn import NCSNpp
     from EMA import EMA

-    torch.manual_seed(args.seed + rank)
-    torch.cuda.manual_seed(args.seed + rank)
-    torch.cuda.manual_seed_all(args.seed + rank)
+    #torch.manual_seed(args.seed + rank)
+    #torch.cuda.manual_seed(args.seed + rank)
+    #torch.cuda.manual_seed_all(args.seed + rank)
     device = "cuda"
-
+    autocast = get_autocast(args.precision)
     batch_size = args.batch_size

     nz = args.nz #latent dimension
@@ -270,11 +277,12 @@ def train(rank, gpu, args):
         ])
     elif args.preprocessing == "random_resized_crop_v1":
         train_transform = transforms.Compose([
-            transforms.RandomResizedCrop(
+            transforms.RandomResizedCrop(args.image_size, scale=(0.95, 1.0), interpolation=3),
             transforms.ToTensor(),
             transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
         ])
-
+    shards = glob(os.path.join(args.dataset_root, "*.tar")) if os.path.isdir(args.dataset_root) else args.dataset_root
+    pipeline = [ResampledShards2(shards)]
     pipeline.extend([
         wds.split_by_node,
         wds.split_by_worker,
@@ -339,6 +347,13 @@
             t_emb_dim = args.t_emb_dim,
             cond_size=text_encoder.output_size,
             act=nn.LeakyReLU(0.2)).to(device)
+    elif args.discr_type == "large_attn_pool":
+        netD = Discriminator_large(nc = 2*args.num_channels, ngf = args.ngf,
+            t_emb_dim = args.t_emb_dim,
+            cond_size=text_encoder.output_size,
+            attn_pool=True,
+            act=nn.LeakyReLU(0.2)).to(device)
+
     elif args.discr_type == "large_cond_attn":
         netD = CondAttnDiscriminator(
             nc = 2*args.num_channels,
@@ -350,6 +365,15 @@
     broadcast_params(netG.parameters())
     broadcast_params(netD.parameters())

+    if args.fsdp:
+        from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
+        from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
+        netG = FSDP(
+            netG,
+            flatten_parameters=True,
+            verbose=True,
+        )
+
     optimizerD = optim.Adam(netD.parameters(), lr=args.lr_d, betas = (args.beta1, args.beta2))
     optimizerG = optim.Adam(netG.parameters(), lr=args.lr_g, betas = (args.beta1, args.beta2))

@@ -358,9 +382,16 @@

     schedulerG = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerG, args.num_epoch, eta_min=1e-5)
     schedulerD = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerD, args.num_epoch, eta_min=1e-5)
+
+    if args.fsdp:
+        netD = nn.parallel.DistributedDataParallel(netD, device_ids=[gpu])
+    else:
+        netG = nn.parallel.DistributedDataParallel(netG, device_ids=[gpu])
+        netD = nn.parallel.DistributedDataParallel(netD, device_ids=[gpu])

-
-
+    if args.grad_checkpointing:
+        from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
+        netG = checkpoint_wrapper(netG)

     exp = args.exp
     parent_dir = "./saved_info/dd_gan/{}".format(args.dataset)
@@ -377,6 +408,10 @@
     T = get_time_schedule(args, device)

     checkpoint_file = os.path.join(exp_path, 'content.pth')
+
+    if rank == 0:
+        log_writer = SummaryWriter(exp_path)
+
     if args.resume and os.path.exists(checkpoint_file):
         checkpoint = torch.load(checkpoint_file, map_location="cpu")
         init_epoch = checkpoint['epoch']
@@ -395,7 +430,7 @@
                   .format(checkpoint['epoch']))
     else:
         global_step, epoch, init_epoch = 0, 0, 0
-    use_cond_attn_discr = args.discr_type in ("large_cond_attn", "small_cond_attn")
+    use_cond_attn_discr = args.discr_type in ("large_cond_attn", "small_cond_attn", "large_attn_pool")
     for epoch in range(init_epoch, args.num_epoch+1):
         if args.dataset == "wds":
             os.environ["WDS_EPOCH"] = str(epoch)
@@ -403,6 +438,7 @@
             train_sampler.set_epoch(epoch)

         for iteration, (x, y) in enumerate(data_loader):
+            #print(x.shape)
             if args.dataset != "wds":
                 y = [str(yi) for yi in y.tolist()]

@@ -437,15 +473,15 @@
            cond_for_discr.requires_grad = True

            # train with real
-
-
-
-
+            with autocast():
+                D_real = netD(x_t, t, x_tp1.detach(), cond=cond_for_discr).view(-1)
+                errD_real = F.softplus(-D_real)
+                errD_real = errD_real.mean()


            errD_real.backward(retain_graph=True)

-
+            grad_penalty = None
            if args.lazy_reg is None:
                if args.grad_penalty_cond:
                    inputs = (x_t,) + (cond,) if use_cond_attn_discr else (cond_for_discr,)
@@ -491,26 +527,36 @@

            # train with fake
            latent_z = torch.randn(batch_size, nz, device=device)
-
-
-
-
-
+            with autocast():
+                if args.grad_checkpointing:
+                    ginp = x_tp1.detach()
+                    ginp.requires_grad = True
+                    latent_z.requires_grad = True
+                    cond_pooled.requires_grad = True
+                    cond.requires_grad = True
+                    #cond_mask.requires_grad = True
+                    x_0_predict = netG(ginp, t, latent_z, cond=(cond_pooled, cond, cond_mask))
+                else:
+                    x_0_predict = netG(x_tp1.detach(), t, latent_z, cond=(cond_pooled, cond, cond_mask))
+                x_pos_sample = sample_posterior(pos_coeff, x_0_predict, x_tp1, t)

-
-
-
+                output = netD(x_pos_sample, t, x_tp1.detach(), cond=cond_for_discr).view(-1)
+
+
+                errD_fake = F.softplus(output)
+                errD_fake = errD_fake.mean()

            if args.mismatch_loss:
                # following https://github.com/tobran/DF-GAN/blob/bc38a4f795c294b09b4ef5579cd4ff78807e5b96/code/lib/modules.py,
                # we add a discr loss for (real image, non matching text)
                #inds = torch.flip(torch.arange(len(x_t)), dims=(0,))
-
-
-
-
-
-
+                with autocast():
+                    inds = torch.cat([torch.arange(1,len(x_t)),torch.arange(1)])
+                    cond_for_discr_mis = (cond_pooled[inds], cond[inds], cond_mask[inds]) if use_cond_attn_discr else cond_pooled[inds]
+                    D_real_mis = netD(x_t, t, x_tp1.detach(), cond=cond_for_discr_mis).view(-1)
+                    errD_real_mis = F.softplus(D_real_mis)
+                    errD_real_mis = errD_real_mis.mean()
+                    errD_fake = errD_fake * 0.5 + errD_real_mis * 0.5

            errD_fake.backward()

@@ -534,58 +580,106 @@

            latent_z = torch.randn(batch_size, nz,device=device)

-
+            with autocast():
+                if args.grad_checkpointing:
+                    ginp = x_tp1.detach()
+                    ginp.requires_grad = True
+                    latent_z.requires_grad = True
+                    cond_pooled.requires_grad = True
+                    cond.requires_grad = True
+                    #cond_mask.requires_grad = True
+                    x_0_predict = netG(ginp, t, latent_z, cond=(cond_pooled, cond, cond_mask))
+                else:
+                    x_0_predict = netG(x_tp1.detach(), t, latent_z, cond=(cond_pooled, cond, cond_mask))
+                x_pos_sample = sample_posterior(pos_coeff, x_0_predict, x_tp1, t)

-
-
-
-
-
-
-                errG = F.softplus(-output)
-                errG = errG.mean()
+                output = netD(x_pos_sample, t, x_tp1.detach(), cond=cond_for_discr).view(-1)
+
+
+                errG = F.softplus(-output)
+                errG = errG.mean()

            errG.backward()
            optimizerG.step()

-
+            if (iteration % 10 == 0) and (rank == 0):
+                log_writer.add_scalar('g_loss', errG.item(), global_step)
+                log_writer.add_scalar('d_loss', errD.item(), global_step)
+                if grad_penalty is not None:
+                    log_writer.add_scalar('grad_penalty', grad_penalty.item(), global_step)

            global_step += 1
+
+
            if iteration % 100 == 0:
                if rank == 0:
                    print('epoch {} iteration{}, G Loss: {}, D Loss: {}'.format(epoch,iteration, errG.item(), errD.item()))
+                    print('Global step:', global_step)
            if iteration % 1000 == 0:
                x_t_1 = torch.randn_like(real_data)
-
+                with autocast():
+                    fake_sample = sample_from_model(pos_coeff, netG, args.num_timesteps, x_t_1, T, args, cond=(cond_pooled, cond, cond_mask))
                if rank == 0:
                    torchvision.utils.save_image(fake_sample, os.path.join(exp_path, 'sample_discrete_epoch_{}_iteration_{}.png'.format(epoch, iteration)), normalize=True)
-
-
+
+                if args.save_content:
+                    dist.barrier()
+                    print('Saving content.')
+                    def to_cpu(d):
+                        for k, v in d.items():
+                            d[k] = v.cpu()
+                        return d
+
+                    if args.fsdp:
+                        netG_state_dict = to_cpu(netG.state_dict())
+                        netD_state_dict = to_cpu(netD.state_dict())
+                        #netG_optim_state_dict = (netG.gather_full_optim_state_dict(optimizerG))
+                        netG_optim_state_dict = optimizerG.state_dict()
+                        #print(netG_optim_state_dict)
+                        netD_optim_state_dict = (optimizerD.state_dict())
                        content = {'epoch': epoch + 1, 'global_step': global_step, 'args': args,
-
-
-
-
-
-
-
-
-
-
-
-
+                                   'netG_dict': netG_state_dict, 'optimizerG': netG_optim_state_dict,
+                                   'schedulerG': schedulerG.state_dict(), 'netD_dict': netD_state_dict,
+                                   'optimizerD': netD_optim_state_dict, 'schedulerD': schedulerD.state_dict()}
+                        if rank == 0:
+                            torch.save(content, os.path.join(exp_path, 'content.pth'))
+                            torch.save(content, os.path.join(exp_path, 'content_backup.pth'))
+                            if args.use_ema:
+                                optimizerG.swap_parameters_with_ema(store_params_in_ema=True)
+                            if args.use_ema and rank == 0:
+                                torch.save(netG.state_dict(), os.path.join(exp_path, 'netG_{}.pth'.format(epoch)))
+                            if args.use_ema:
+                                optimizerG.swap_parameters_with_ema(store_params_in_ema=True)
+                        #if args.use_ema:
+                        #    dist.barrier()
+                        print("Saved content")
+                    else:
+                        if rank == 0:
+                            content = {'epoch': epoch + 1, 'global_step': global_step, 'args': args,
+                                       'netG_dict': netG.state_dict(), 'optimizerG': optimizerG.state_dict(),
+                                       'schedulerG': schedulerG.state_dict(), 'netD_dict': netD.state_dict(),
+                                       'optimizerD': optimizerD.state_dict(), 'schedulerD': schedulerD.state_dict()}
+                            torch.save(content, os.path.join(exp_path, 'content.pth'))
+                            torch.save(content, os.path.join(exp_path, 'content_backup.pth'))
+                            if args.use_ema:
+                                optimizerG.swap_parameters_with_ema(store_params_in_ema=True)
+                                torch.save(netG.state_dict(), os.path.join(exp_path, 'netG_{}.pth'.format(epoch)))
+                            if args.use_ema:
+                                optimizerG.swap_parameters_with_ema(store_params_in_ema=True)
+

        if not args.no_lr_decay:

            schedulerG.step()
            schedulerD.step()
-
+        """
        if rank == 0:
            if epoch % 10 == 0:
                torchvision.utils.save_image(x_pos_sample, os.path.join(exp_path, 'xpos_epoch_{}.png'.format(epoch)), normalize=True)

            x_t_1 = torch.randn_like(real_data)
-
+            with autocast():
+                fake_sample = sample_from_model(pos_coeff, netG, args.num_timesteps, x_t_1, T, args, cond=(cond_pooled, cond, cond_mask))
            torchvision.utils.save_image(fake_sample, os.path.join(exp_path, 'sample_discrete_epoch_{}.png'.format(epoch)), normalize=True)

            if args.save_content:
@@ -606,7 +700,8 @@
                    torch.save(netG.state_dict(), os.path.join(exp_path, 'netG_{}.pth'.format(epoch)))
                    if args.use_ema:
                        optimizerG.swap_parameters_with_ema(store_params_in_ema=True)
-
+        dist.barrier()
+        """


 def init_processes(rank, size, fn, args):
@@ -641,6 +736,8 @@ if __name__ == '__main__':
    parser.add_argument('--mismatch_loss', action='store_true',default=False)
    parser.add_argument('--text_encoder', type=str, default="google/t5-v1_1-base")
    parser.add_argument('--cross_attention', action='store_true',default=False)
+    parser.add_argument('--fsdp', action='store_true',default=False)
+    parser.add_argument('--grad_checkpointing', action='store_true',default=False)

    parser.add_argument('--image_size', type=int, default=32,
                        help='size of image')
@@ -728,6 +825,7 @@ if __name__ == '__main__':
    parser.add_argument('--save_ckpt_every', type=int, default=25, help='save ckpt every x epochs')
    parser.add_argument('--discr_type', type=str, default="large")
    parser.add_argument('--preprocessing', type=str, default="resize")
+    parser.add_argument('--precision', type=str, default="fp32")

    ###ddp
    parser.add_argument('--num_proc_node', type=int, default=1,
@@ -746,4 +844,4 @@ if __name__ == '__main__':
    args.world_size = int(os.getenv("SLURM_NTASKS"))
    args.rank = int(os.environ['SLURM_PROCID'])
    # size = args.num_process_per_node
-    init_processes(args.rank, args.world_size, train, args)
+    init_processes(args.rank, args.world_size, train, args)
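Usage note on the new --precision flag: get_autocast returns a context-manager factory that the training loop calls around every forward pass; a minimal, self-contained restatement of the helper (runnable on CPU for the fp32 case).

import torch
from contextlib import suppress

def get_autocast(precision):
    if precision == 'amp':
        return torch.cuda.amp.autocast
    elif precision == 'amp_bfloat16':
        return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16)
    else:
        return suppress  # anything else (e.g. fp32) becomes a no-op context manager

autocast = get_autocast('fp32')
with autocast():
    y = torch.randn(4, 4) @ torch.randn(4, 4)  # plain fp32 here; half/bfloat16 under the amp options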
utils.py
CHANGED
@@ -41,7 +41,8 @@ class ResampledShards2(IterableDataset):
         """
         super().__init__()
         #urls = wds.shardlists.expand_urls(urls)
-        urls
+        if type(urls) != list:
+            urls = list(braceexpand.braceexpand(urls))
         self.urls = urls
         assert isinstance(self.urls[0], str)
         self.nshards = nshards
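Usage note: the new branch lets a single brace pattern stand in for an explicit list of shard tar files; a sketch with a hypothetical pattern.

import braceexpand

urls = "laion-shard-{00000..00002}.tar"   # hypothetical shard pattern
print(list(braceexpand.braceexpand(urls)))
# ['laion-shard-00000.tar', 'laion-shard-00001.tar', 'laion-shard-00002.tar']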