Upload 225 files
This view is limited to 50 files because it contains too many changes.
- configs/model/audioldm.yaml +24 -0
- configs/model/clap.yaml +10 -0
- configs/model/clip.yaml +22 -0
- configs/model/codi.yaml +23 -0
- configs/model/openai_unet.yaml +85 -0
- configs/model/optimus.yaml +107 -0
- configs/model/prova.yaml +85 -0
- configs/model/sd.yaml +20 -0
- configs/model/thesis_model.yaml +21 -0
- core/__init__.py +0 -0
- core/__pycache__/__init__.cpython-38.pyc +0 -0
- core/__pycache__/cfg_helper.cpython-38.pyc +0 -0
- core/__pycache__/cfg_holder.cpython-38.pyc +0 -0
- core/__pycache__/sync.cpython-38.pyc +0 -0
- core/cfg_helper.py +665 -0
- core/cfg_holder.py +33 -0
- core/common/__pycache__/utils.cpython-38.pyc +0 -0
- core/common/registry.py +86 -0
- core/common/utils.py +412 -0
- core/models/__init__.py +4 -0
- core/models/__pycache__/__init__.cpython-38.pyc +0 -0
- core/models/__pycache__/codi.cpython-38.pyc +0 -0
- core/models/__pycache__/codi_2.cpython-38.pyc +0 -0
- core/models/__pycache__/dani_model.cpython-38.pyc +0 -0
- core/models/__pycache__/ema.cpython-38.pyc +0 -0
- core/models/__pycache__/model_module_infer.cpython-38.pyc +0 -0
- core/models/__pycache__/sd.cpython-38.pyc +0 -0
- core/models/codi.py +227 -0
- core/models/codi_2.py +221 -0
- core/models/common/__pycache__/get_model.cpython-38.pyc +0 -0
- core/models/common/__pycache__/get_optimizer.cpython-38.pyc +0 -0
- core/models/common/__pycache__/get_scheduler.cpython-38.pyc +0 -0
- core/models/common/__pycache__/utils.cpython-38.pyc +0 -0
- core/models/common/get_model.py +88 -0
- core/models/common/get_optimizer.py +50 -0
- core/models/common/get_scheduler.py +273 -0
- core/models/common/utils.py +310 -0
- core/models/dani_model.py +170 -0
- core/models/ddim/__pycache__/ddim.cpython-38.pyc +0 -0
- core/models/ddim/__pycache__/ddim_vd.cpython-38.pyc +0 -0
- core/models/ddim/__pycache__/diffusion_utils.cpython-38.pyc +0 -0
- core/models/ddim/ddim.py +224 -0
- core/models/ddim/ddim_vd.py +175 -0
- core/models/ddim/diffusion_utils.py +273 -0
- core/models/ema.py +76 -0
- core/models/encoders/__pycache__/clap.cpython-311.pyc +0 -0
- core/models/encoders/__pycache__/clap.cpython-38.pyc +0 -0
- core/models/encoders/__pycache__/clip.cpython-311.pyc +0 -0
- core/models/encoders/__pycache__/clip.cpython-38.pyc +0 -0
- core/models/encoders/clap.py +134 -0
configs/model/audioldm.yaml
ADDED
@@ -0,0 +1,24 @@
+########################
+# audioldm autoencoder #
+########################
+
+
+audioldm_autoencoder:
+  type: audioldm_autoencoder
+  args:
+    embed_dim: 8
+    monitor: val/rec_loss
+    ddconfig:
+      double_z: True
+      z_channels: 8
+      resolution: 256
+      downsample_time: False
+      in_channels: 1
+      out_ch: 1
+      ch: 128
+      ch_mult: [1, 2, 4]
+      num_res_blocks: 2
+      attn_resolutions: []
+      dropout: 0.0
+    lossconfig:
+      target: torch.nn.Identity

configs/model/clap.yaml
ADDED
@@ -0,0 +1,10 @@
+######################
+# clap audio encoder #
+######################
+
+
+clap_audio:
+  type: clap_audio
+  args:
+    amodel: "HTSAT-large"
+    joint_embed_shape: 768

configs/model/clip.yaml
ADDED
@@ -0,0 +1,22 @@
+##############################
+# clip vision & text encoder #
+##############################
+
+clip:
+  symbol: clip
+  args: {}
+
+clip_frozen:
+  super_cfg: clip
+  type: clip_frozen
+  args: {}
+
+clip_text:
+  super_cfg: clip
+  type: clip_text
+  args: {}
+
+clip_vision:
+  super_cfg: clip
+  type: clip_vision
+  args: {}

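The `clip_frozen`, `clip_text`, and `clip_vision` entries above inherit from the base `clip` entry through `super_cfg`; the merge itself is implemented by `model_cfg_bank` in core/cfg_helper.py further down in this diff. A minimal sketch of what the resolved entry looks like (illustrative only, not part of the commit):

# Hypothetical sketch: the dict returned by model_cfg_bank()('clip_frozen')
# after super_cfg resolution (field names follow the YAML above).
resolved_clip_frozen = {
    'symbol': 'clip',        # inherited from the 'clip' base entry
    'type': 'clip_frozen',   # set by the child entry
    'args': {},              # child args are update()d over the parent args
    'name': 'clip_frozen',   # added by model_cfg_bank at load time
}
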
configs/model/codi.yaml
ADDED
@@ -0,0 +1,23 @@
+########
+# CoDi #
+########
+
+codi:
+  type: codi
+  symbol: codi
+  find_unused_parameters: true
+  args:
+    audioldm_cfg: MODEL(audioldm_autoencoder)
+    autokl_cfg: MODEL(sd_autoencoder)
+    optimus_cfg: MODEL(optimus_vae)
+    clip_cfg: MODEL(clip_frozen)
+    clap_cfg: MODEL(clap_audio)
+    unet_config: MODEL(openai_unet_codi)
+    beta_linear_start: 0.00085
+    beta_linear_end: 0.012
+    timesteps: 1000
+    vision_scale_factor: 0.18215
+    text_scale_factor: 4.3108
+    audio_scale_factor: 0.9228
+    use_ema: false
+    parameterization: "eps"

configs/model/openai_unet.yaml
ADDED
@@ -0,0 +1,85 @@
+openai_unet_sd:
+  type: openai_unet
+  args:
+    image_size: null # no use
+    in_channels: 4
+    out_channels: 4
+    model_channels: 320
+    attention_resolutions: [ 4, 2, 1 ]
+    num_res_blocks: [ 2, 2, 2, 2 ]
+    channel_mult: [ 1, 2, 4, 4 ]
+    num_heads: 8
+    use_spatial_transformer: True
+    transformer_depth: 1
+    context_dim: 768
+    use_checkpoint: True
+    legacy: False
+
+openai_unet_dual_context:
+  super_cfg: openai_unet_sd
+  type: openai_unet_dual_context
+
+########################
+# Code cleaned version #
+########################
+
+openai_unet_2d_audio:
+  type: openai_unet_2d
+  args:
+    input_channels: 8
+    model_channels: 192
+    output_channels: 8
+    num_noattn_blocks: [ 2, 2, 2, 2 ]
+    channel_mult: [ 1, 2, 4, 4 ]
+    with_attn: [true, true, true, false]
+    channel_mult_connector: [1, 2, 4]
+    num_noattn_blocks_connector: [1, 1, 1]
+    with_connector: [True, True, True, False]
+    connector_output_channel: 1280
+    num_heads: 8
+    context_dim: 768
+    use_checkpoint: False
+
+openai_unet_2d:
+  type: openai_unet_2d
+  args:
+    input_channels: 4
+    model_channels: 320
+    output_channels: 4
+    num_noattn_blocks: [ 2, 2, 2, 2 ]
+    channel_mult: [ 1, 2, 4, 4 ]
+    with_attn: [true, true, true, false]
+    channel_mult_connector: [1, 2, 4]
+    num_noattn_blocks_connector: [1, 1, 1]
+    with_connector: [True, True, True, False]
+    connector_output_channel: 1280
+    num_heads: 8
+    context_dim: 768
+    use_checkpoint: True
+    use_video_architecture: True
+
+openai_unet_0dmd:
+  type: openai_unet_0dmd
+  args:
+    input_channels: 768
+    model_channels: 320
+    output_channels: 768
+    num_noattn_blocks: [ 2, 2, 2, 2 ]
+    channel_mult: [ 1, 2, 4, 4 ]
+    second_dim: [ 4, 4, 4, 4 ]
+    with_attn: [true, true, true, false]
+    num_noattn_blocks_connector: [1, 1, 1]
+    second_dim_connector: [4, 4, 4]
+    with_connector: [True, True, True, False]
+    connector_output_channel: 1280
+    num_heads: 8
+    context_dim: 768
+    use_checkpoint: True
+
+openai_unet_codi:
+  type: openai_unet_codi
+  args:
+    unet_image_cfg: MODEL(openai_unet_2d)
+    unet_text_cfg: MODEL(openai_unet_0dmd)
+    unet_audio_cfg: MODEL(openai_unet_2d_audio)
+    model_type: ['video', 'image', 'text']

configs/model/optimus.yaml
ADDED
@@ -0,0 +1,107 @@
+
+optimus:
+  symbol: optimus
+  find_unused_parameters: false
+  args: {}
+
+optimus_bert_encoder:
+  super_cfg: optimus
+  type: optimus_bert_connector
+  # pth: pretrained/optimus_bert_encoder.pth
+  args:
+    config:
+      architectures:
+        - BertForMaskedLM
+      attention_probs_dropout_prob: 0.1
+      finetuning_task: null
+      hidden_act: gelu
+      hidden_dropout_prob: 0.1
+      hidden_size: 768
+      initializer_range: 0.02
+      intermediate_size: 3072
+      layer_norm_eps: 1.e-12
+      max_position_embeddings: 512
+      num_attention_heads: 12
+      num_hidden_layers: 12
+      num_labels: 2
+      output_attentions: false
+      output_hidden_states: false
+      pruned_heads: {}
+      torchscript: false
+      type_vocab_size: 2
+      vocab_size: 28996
+    latent_size: 768
+
+optimus_bert_tokenizer:
+  super_cfg: optimus
+  type: optimus_bert_tokenizer
+  args:
+    do_lower_case: false
+    max_len: 512
+    vocab_file: core/models/latent_diffusion/vae/optimus_modules/vocab/bert-base-cased-vocab.txt
+
+optimus_gpt2_decoder:
+  super_cfg: optimus
+  type: optimus_gpt2_connector
+  # pth: pretrained/optimus_gpt2_decoder.pth
+  args:
+    config:
+      architectures:
+        - GPT2LMHeadModel
+      attn_pdrop: 0.1
+      embd_pdrop: 0.1
+      finetuning_task: null
+      hidden_size: 768
+      initializer_range: 0.02
+      latent_size: 768
+      layer_norm_epsilon: 1.e-05
+      max_position_embeddings: 1024
+      n_ctx: 1024
+      n_embd: 768
+      n_head: 12
+      n_layer: 12
+      n_positions: 1024
+      num_attention_heads: 12
+      num_hidden_layers: 12
+      num_labels: 1
+      output_attentions: false
+      output_hidden_states: false
+      pretrained_config_archive_map:
+        gpt2: https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json
+        gpt2-medium: https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json
+        gpt2-large: https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-config.json
+      pruned_heads: {}
+      resid_pdrop: 0.1
+      summary_activation: null
+      summary_first_dropout: 0.1
+      summary_proj_to_labels: true
+      summary_type: cls_index
+      summary_use_proj: true
+      torchscript: false
+      vocab_size: 50260
+
+optimus_gpt2_tokenizer:
+  super_cfg: optimus
+  type: optimus_gpt2_tokenizer
+  args:
+    do_lower_case: false
+    max_len: 1024
+    vocab_file: core/models/latent_diffusion/vae/optimus_modules/vocab/gpt2-vocab.json
+    merges_file: core/models/latent_diffusion/vae/optimus_modules/vocab/gpt2-merges.txt
+
+optimus_vae:
+  super_cfg: optimus
+  type: optimus_vae
+  pth: pretrained/optimus-vae.pth
+  args:
+    encoder: MODEL(optimus_bert_encoder)
+    decoder: MODEL(optimus_gpt2_decoder)
+    tokenizer_encoder: MODEL(optimus_bert_tokenizer)
+    tokenizer_decoder: MODEL(optimus_gpt2_tokenizer)
+    args:
+      latent_size: 768
+      beta: 1.0
+      fb_mode: 0
+      length_weighted_loss: false
+      dim_target_kl: 3.0
+

configs/model/prova.yaml
ADDED
@@ -0,0 +1,85 @@
+openai_unet_sd:
+  type: openai_unet
+  args:
+    image_size: null # no use
+    in_channels: 4
+    out_channels: 4
+    model_channels: 320
+    attention_resolutions: [ 4, 2, 1 ]
+    num_res_blocks: [ 2, 2, 2, 2 ]
+    channel_mult: [ 1, 2, 4, 4 ]
+    num_heads: 8
+    use_spatial_transformer: True
+    transformer_depth: 1
+    context_dim: 768
+    use_checkpoint: True
+    legacy: False
+
+openai_unet_dual_context:
+  super_cfg: openai_unet_sd
+  type: openai_unet_dual_context
+
+########################
+# Code cleaned version #
+########################
+
+openai_unet_2d_audio:
+  type: openai_unet_2d
+  args:
+    input_channels: 8
+    model_channels: 192
+    output_channels: 8
+    num_noattn_blocks: [ 2, 2, 2, 2 ]
+    channel_mult: [ 1, 2, 4, 4 ]
+    with_attn: [true, true, true, false]
+    channel_mult_connector: [1, 2, 4]
+    num_noattn_blocks_connector: [1, 1, 1]
+    with_connector: [True, True, True, False]
+    connector_output_channel: 1280
+    num_heads: 8
+    context_dim: 768
+    use_checkpoint: False
+
+openai_unet_2d:
+  type: openai_unet_2d
+  args:
+    input_channels: 4
+    model_channels: 320
+    output_channels: 4
+    num_noattn_blocks: [ 2, 2, 2, 2 ]
+    channel_mult: [ 1, 2, 4, 4 ]
+    with_attn: [true, true, true, false]
+    channel_mult_connector: [1, 2, 4]
+    num_noattn_blocks_connector: [1, 1, 1]
+    with_connector: [True, True, True, False]
+    connector_output_channel: 1280
+    num_heads: 8
+    context_dim: 768
+    use_checkpoint: True
+    use_video_architecture: True
+
+openai_unet_0dmd:
+  type: openai_unet_0dmd
+  args:
+    input_channels: 768
+    model_channels: 320
+    output_channels: 768
+    num_noattn_blocks: [ 2, 2, 2, 2 ]
+    channel_mult: [ 1, 2, 4, 4 ]
+    second_dim: [ 4, 4, 4, 4 ]
+    with_attn: [true, true, true, false]
+    num_noattn_blocks_connector: [1, 1, 1]
+    second_dim_connector: [4, 4, 4]
+    with_connector: [True, True, True, False]
+    connector_output_channel: 1280
+    num_heads: 8
+    context_dim: 768
+    use_checkpoint: True
+
+prova:
+  type: prova
+  args:
+    unet_frontal_cfg: MODEL(openai_unet_2d)
+    unet_lateral_cfg: MODEL(openai_unet_2d)
+    unet_text_cfg: MODEL(openai_unet_0dmd)
+    model_type: ['text']

configs/model/sd.yaml
ADDED
@@ -0,0 +1,20 @@
+sd_autoencoder:
+  type: autoencoderkl
+  args:
+    embed_dim: 4
+    monitor: val/rec_loss
+    ddconfig:
+      double_z: true
+      z_channels: 4
+      resolution: 256
+      in_channels: 3
+      out_ch: 3
+      ch: 128
+      ch_mult: [1, 2, 4, 4]
+      num_res_blocks: 2
+      attn_resolutions: []
+      dropout: 0.0
+      use_video_arch: true
+    lossconfig:
+      target: torch.nn.Identity
+  pth: pretrained/kl-f8.pth

configs/model/thesis_model.yaml
ADDED
@@ -0,0 +1,21 @@
+########
+# CoDi #
+########
+
+thesis_model:
+  type: thesis_model
+  symbol: thesis_model
+  find_unused_parameters: true
+  args:
+    autokl_cfg: MODEL(sd_autoencoder)
+    optimus_cfg: MODEL(optimus_vae)
+    clip_cfg: MODEL(clip_frozen)
+    unet_config: MODEL(prova)
+    beta_linear_start: 0.00085
+    beta_linear_end: 0.012
+    timesteps: 1000
+    vision_scale_factor: 0.18215
+    text_scale_factor: 4.3108
+    audio_scale_factor: 0.9228
+    use_ema: false
+    parameterization: "eps"

core/__init__.py
ADDED
File without changes

core/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (149 Bytes).

core/__pycache__/cfg_helper.cpython-38.pyc
ADDED
Binary file (13 kB).

core/__pycache__/cfg_holder.cpython-38.pyc
ADDED
Binary file (1.21 kB).

core/__pycache__/sync.cpython-38.pyc
ADDED
Binary file (6.24 kB).

core/cfg_helper.py
ADDED
@@ -0,0 +1,665 @@
+import os
+import os.path as osp
+import shutil
+import copy
+import time
+import pprint
+import numpy as np
+import torch
+import argparse
+import json
+import yaml
+from easydict import EasyDict as edict
+
+from core.models import get_model
+
+############
+# cfg_bank #
+############
+
+
+def cfg_solvef(cmd, root):
+    if not isinstance(cmd, str):
+        return cmd
+
+    if cmd.find('SAME')==0:
+        zoom = root
+        p = cmd[len('SAME'):].strip('()').split('.')
+        p = [pi.strip() for pi in p]
+        for pi in p:
+            try:
+                pi = int(pi)
+            except:
+                pass
+
+            try:
+                zoom = zoom[pi]
+            except:
+                return cmd
+        return cfg_solvef(zoom, root)
+
+    if cmd.find('SEARCH')==0:
+        zoom = root
+        p = cmd[len('SEARCH'):].strip('()').split('.')
+        p = [pi.strip() for pi in p]
+        find = True
+        # Depth first search
+        for pi in p:
+            try:
+                pi = int(pi)
+            except:
+                pass
+
+            try:
+                zoom = zoom[pi]
+            except:
+                find = False
+                break
+
+        if find:
+            return cfg_solvef(zoom, root)
+        else:
+            if isinstance(root, dict):
+                for ri in root:
+                    rv = cfg_solvef(cmd, root[ri])
+                    if rv != cmd:
+                        return rv
+            if isinstance(root, list):
+                for ri in root:
+                    rv = cfg_solvef(cmd, ri)
+                    if rv != cmd:
+                        return rv
+            return cmd
+
+    if cmd.find('MODEL')==0:
+        goto = cmd[len('MODEL'):].strip('()')
+        return model_cfg_bank()(goto)
+
+    if cmd.find('DATASET')==0:
+        goto = cmd[len('DATASET'):].strip('()')
+        return dataset_cfg_bank()(goto)
+
+    return cmd
+
+
+def cfg_solve(cfg, cfg_root):
+    # The function solve cfg element such that
+    # all sorrogate input are settled.
+    # (i.e. SAME(***) )
+    if isinstance(cfg, list):
+        for i in range(len(cfg)):
+            if isinstance(cfg[i], (list, dict)):
+                cfg[i] = cfg_solve(cfg[i], cfg_root)
+            else:
+                cfg[i] = cfg_solvef(cfg[i], cfg_root)
+    if isinstance(cfg, dict):
+        for k in cfg:
+            if isinstance(cfg[k], (list, dict)):
+                cfg[k] = cfg_solve(cfg[k], cfg_root)
+            else:
+                cfg[k] = cfg_solvef(cfg[k], cfg_root)
+    return cfg
+
+
+class model_cfg_bank(object):
+    def __init__(self):
+        self.cfg_dir = osp.join('configs', 'model')
+        self.cfg_bank = edict()
+
+    def __call__(self, name):
+        if name not in self.cfg_bank:
+            cfg_path = self.get_yaml_path(name)
+            with open(cfg_path, 'r') as f:
+                cfg_new = yaml.load(
+                    f, Loader=yaml.FullLoader)
+            cfg_new = edict(cfg_new)
+            self.cfg_bank.update(cfg_new)
+
+        cfg = self.cfg_bank[name]
+        cfg.name = name
+        if 'super_cfg' not in cfg:
+            cfg = cfg_solve(cfg, cfg)
+            self.cfg_bank[name] = cfg
+            return copy.deepcopy(cfg)
+
+        super_cfg = self.__call__(cfg.super_cfg)
+        # unlike other field,
+        # args will not be replaced but update.
+        if 'args' in cfg:
+            if 'args' in super_cfg:
+                super_cfg.args.update(cfg.args)
+            else:
+                super_cfg.args = cfg.args
+            cfg.pop('args')
+
+        super_cfg.update(cfg)
+        super_cfg.pop('super_cfg')
+        cfg = super_cfg
+        try:
+            delete_args = cfg.pop('delete_args')
+        except:
+            delete_args = []
+
+        for dargs in delete_args:
+            cfg.args.pop(dargs)
+
+        cfg = cfg_solve(cfg, cfg)
+        self.cfg_bank[name] = cfg
+        return copy.deepcopy(cfg)
+
+    def get_yaml_path(self, name):
+        if name.find('openai_unet')==0:
+            return osp.join(
+                self.cfg_dir, 'openai_unet.yaml')
+        elif name.find('prova')==0:
+            return osp.join(
+                self.cfg_dir, 'prova.yaml')
+        elif name.find('audioldm')==0:
+            return osp.join(
+                self.cfg_dir, 'audioldm.yaml')
+        elif name.find('clip')==0:
+            return osp.join(
+                self.cfg_dir, 'clip.yaml')
+        elif name.find('sd')==0:
+            return osp.join(
+                self.cfg_dir, 'sd.yaml')
+        elif name.find('codi')==0:
+            return osp.join(
+                self.cfg_dir, 'codi.yaml')
+        elif name.find('thesis_model')==0:
+            return osp.join(
+                self.cfg_dir, 'thesis_model.yaml')
+        elif name.find('clap')==0:
+            return osp.join(
+                self.cfg_dir, 'clap.yaml')
+        elif name.find('optimus')==0:
+            return osp.join(
+                self.cfg_dir, 'optimus.yaml')
+        else:
+            raise ValueError
+
+
+class dataset_cfg_bank(object):
+    def __init__(self):
+        self.cfg_dir = osp.join('configs', 'dataset')
+        self.cfg_bank = edict()
+
+    def __call__(self, name):
+        if name not in self.cfg_bank:
+            cfg_path = self.get_yaml_path(name)
+            with open(cfg_path, 'r') as f:
+                cfg_new = yaml.load(
+                    f, Loader=yaml.FullLoader)
+            cfg_new = edict(cfg_new)
+            self.cfg_bank.update(cfg_new)
+
+        cfg = self.cfg_bank[name]
+        cfg.name = name
+        if cfg.get('super_cfg', None) is None:
+            cfg = cfg_solve(cfg, cfg)
+            self.cfg_bank[name] = cfg
+            return copy.deepcopy(cfg)
+
+        super_cfg = self.__call__(cfg.super_cfg)
+        super_cfg.update(cfg)
+        cfg = super_cfg
+        cfg.super_cfg = None
+        try:
+            delete = cfg.pop('delete')
+        except:
+            delete = []
+
+        for dargs in delete:
+            cfg.pop(dargs)
+
+        cfg = cfg_solve(cfg, cfg)
+        self.cfg_bank[name] = cfg
+        return copy.deepcopy(cfg)
+
+    def get_yaml_path(self, name):
+        if name.find('cityscapes')==0:
+            return osp.join(
+                self.cfg_dir, 'cityscapes.yaml')
+        elif name.find('div2k')==0:
+            return osp.join(
+                self.cfg_dir, 'div2k.yaml')
+        elif name.find('gandiv2k')==0:
+            return osp.join(
+                self.cfg_dir, 'gandiv2k.yaml')
+        elif name.find('srbenchmark')==0:
+            return osp.join(
+                self.cfg_dir, 'srbenchmark.yaml')
+        elif name.find('imagedir')==0:
+            return osp.join(
+                self.cfg_dir, 'imagedir.yaml')
+        elif name.find('places2')==0:
+            return osp.join(
+                self.cfg_dir, 'places2.yaml')
+        elif name.find('ffhq')==0:
+            return osp.join(
+                self.cfg_dir, 'ffhq.yaml')
+        elif name.find('imcpt')==0:
+            return osp.join(
+                self.cfg_dir, 'imcpt.yaml')
+        elif name.find('texture')==0:
+            return osp.join(
+                self.cfg_dir, 'texture.yaml')
+        elif name.find('openimages')==0:
+            return osp.join(
+                self.cfg_dir, 'openimages.yaml')
+        elif name.find('laion2b')==0:
+            return osp.join(
+                self.cfg_dir, 'laion2b.yaml')
+        elif name.find('laionart')==0:
+            return osp.join(
+                self.cfg_dir, 'laionart.yaml')
+        elif name.find('celeba')==0:
+            return osp.join(
+                self.cfg_dir, 'celeba.yaml')
+        elif name.find('coyo')==0:
+            return osp.join(
+                self.cfg_dir, 'coyo.yaml')
+        elif name.find('pafc')==0:
+            return osp.join(
+                self.cfg_dir, 'pafc.yaml')
+        elif name.find('coco')==0:
+            return osp.join(
+                self.cfg_dir, 'coco.yaml')
+        else:
+            raise ValueError
+
+
+class experiment_cfg_bank(object):
+    def __init__(self):
+        self.cfg_dir = osp.join('configs', 'experiment')
+        self.cfg_bank = edict()
+
+    def __call__(self, name):
+        if name not in self.cfg_bank:
+            cfg_path = self.get_yaml_path(name)
+            with open(cfg_path, 'r') as f:
+                cfg = yaml.load(
+                    f, Loader=yaml.FullLoader)
+            cfg = edict(cfg)
+
+        cfg = cfg_solve(cfg, cfg)
+        cfg = cfg_solve(cfg, cfg)
+        # twice for SEARCH
+        self.cfg_bank[name] = cfg
+        return copy.deepcopy(cfg)
+
+    def get_yaml_path(self, name):
+        return osp.join(
+            self.cfg_dir, name+'.yaml')
+
+
+def load_cfg_yaml(path):
+    if osp.isfile(path):
+        cfg_path = path
+    elif osp.isfile(osp.join('configs', 'experiment', path)):
+        cfg_path = osp.join('configs', 'experiment', path)
+    elif osp.isfile(osp.join('configs', 'experiment', path+'.yaml')):
+        cfg_path = osp.join('configs', 'experiment', path+'.yaml')
+    else:
+        assert False, 'No such config!'
+
+    with open(cfg_path, 'r') as f:
+        cfg = yaml.load(f, Loader=yaml.FullLoader)
+    cfg = edict(cfg)
+    cfg = cfg_solve(cfg, cfg)
+    cfg = cfg_solve(cfg, cfg)
+    return cfg
+
+##############
+# cfg_helper #
+##############
+
+
+def get_experiment_id(ref=None):
+    if ref is None:
+        time.sleep(0.5)
+        return int(time.time()*100)
+    else:
+        try:
+            return int(ref)
+        except:
+            pass
+
+        _, ref = osp.split(ref)
+        ref = ref.split('_')[0]
+        try:
+            return int(ref)
+        except:
+            assert False, 'Invalid experiment ID!'
+
+
+def record_resume_cfg(path):
+    cnt = 0
+    while True:
+        if osp.exists(path+'.{:04d}'.format(cnt)):
+            cnt += 1
+            continue
+        shutil.copyfile(path, path+'.{:04d}'.format(cnt))
+        break
+
+
+def get_command_line_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--debug', action='store_true', default=False)
+    parser.add_argument('--config', type=str)
+    parser.add_argument('--gpu', nargs='+', type=int)
+
+    parser.add_argument('--node_rank', type=int, default=0)
+    parser.add_argument('--nodes', type=int, default=1)
+    parser.add_argument('--addr', type=str, default='127.0.0.1')
+    parser.add_argument('--port', type=int, default=11233)
+
+    parser.add_argument('--signature', nargs='+', type=str)
+    parser.add_argument('--seed', type=int)
+
+    parser.add_argument('--eval', type=str)
+    parser.add_argument('--eval_subdir', type=str)
+    parser.add_argument('--pretrained', type=str)
+
+    parser.add_argument('--resume_dir', type=str)
+    parser.add_argument('--resume_step', type=int)
+    parser.add_argument('--resume_weight', type=str)
+
+    args = parser.parse_args()
+
+    # Special handling the resume
+    if args.resume_dir is not None:
+        cfg = edict()
+        cfg.env = edict()
+        cfg.env.debug = args.debug
+        cfg.env.resume = edict()
+        cfg.env.resume.dir = args.resume_dir
+        cfg.env.resume.step = args.resume_step
+        cfg.env.resume.weight = args.resume_weight
+        return cfg
+
+    cfg = load_cfg_yaml(args.config)
+    cfg.env.debug = args.debug
+    cfg.env.gpu_device = [0] if args.gpu is None else list(args.gpu)
+    cfg.env.master_addr = args.addr
+    cfg.env.master_port = args.port
+    cfg.env.dist_url = 'tcp://{}:{}'.format(args.addr, args.port)
+    cfg.env.node_rank = args.node_rank
+    cfg.env.nodes = args.nodes
+
+    istrain = False if args.eval is not None else True
+    isdebug = cfg.env.debug
+
+    if istrain:
+        if isdebug:
+            cfg.env.experiment_id = 999999999999
+            cfg.train.signature = ['debug']
+        else:
+            cfg.env.experiment_id = get_experiment_id()
+            if args.signature is not None:
+                cfg.train.signature = args.signature
+    else:
+        if 'train' in cfg:
+            cfg.pop('train')
+        cfg.env.experiment_id = get_experiment_id(args.eval)
+        if args.signature is not None:
+            cfg.eval.signature = args.signature
+
+        if isdebug and (args.eval is None):
+            cfg.env.experiment_id = 999999999999
+            cfg.eval.signature = ['debug']
+
+    if args.eval_subdir is not None:
+        if isdebug:
+            cfg.eval.eval_subdir = 'debug'
+        else:
+            cfg.eval.eval_subdir = args.eval_subdir
+    if args.pretrained is not None:
+        cfg.eval.pretrained = args.pretrained
+        # The override pretrained over the setting in cfg.model
+    if args.seed is not None:
+        cfg.env.rnd_seed = args.seed
+    return cfg
+
+
+def cfg_initiates(cfg):
+    cfge = cfg.env
+    isdebug = cfge.debug
+    isresume = 'resume' in cfge
+    istrain = 'train' in cfg
+    haseval = 'eval' in cfg
+    cfgt = cfg.train if istrain else None
+    cfgv = cfg.eval if haseval else None
+
+    ###############################
+    # get some environment params #
+    ###############################
+
+    cfge.computer = os.uname()
+    cfge.torch_version = str(torch.__version__)
+
+    ##########
+    # resume #
+    ##########
+
+    if isresume:
+        resume_cfg_path = osp.join(cfge.resume.dir, 'config.yaml')
+        record_resume_cfg(resume_cfg_path)
+        with open(resume_cfg_path, 'r') as f:
+            cfg_resume = yaml.load(f, Loader=yaml.FullLoader)
+        cfg_resume = edict(cfg_resume)
+        cfg_resume.env.update(cfge)
+        cfg = cfg_resume
+        cfge = cfg.env
+        log_file = cfg.train.log_file
+
+        print('')
+        print('##########')
+        print('# resume #')
+        print('##########')
+        print('')
+        with open(log_file, 'a') as f:
+            print('', file=f)
+            print('##########', file=f)
+            print('# resume #', file=f)
+            print('##########', file=f)
+            print('', file=f)
+
+        pprint.pprint(cfg)
+        with open(log_file, 'a') as f:
+            pprint.pprint(cfg, f)
+
+    ####################
+    # node distributed #
+    ####################
+
+    if cfg.env.master_addr!='127.0.0.1':
+        os.environ['MASTER_ADDR'] = cfge.master_addr
+        os.environ['MASTER_PORT'] = '{}'.format(cfge.master_port)
+        if cfg.env.dist_backend=='nccl':
+            os.environ['NCCL_SOCKET_FAMILY'] = 'AF_INET'
+        if cfg.env.dist_backend=='gloo':
+            os.environ['GLOO_SOCKET_FAMILY'] = 'AF_INET'
+
+    #######################
+    # cuda visible device #
+    #######################
+
+    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
+        [str(gid) for gid in cfge.gpu_device])
+
+    #####################
+    # return resume cfg #
+    #####################
+
+    if isresume:
+        return cfg
+
+    #############################################
+    # some misc setting that not need in resume #
+    #############################################
+
+    cfgm = cfg.model
+    cfge.gpu_count = len(cfge.gpu_device)
+
+    ##########################################
+    # align batch size and num worker config #
+    ##########################################
+
+    gpu_n = cfge.gpu_count * cfge.nodes
+
+    def align_batch_size(bs, bs_per_gpu):
+        assert (bs is not None) or (bs_per_gpu is not None)
+        bs = bs_per_gpu * gpu_n if bs is None else bs
+        bs_per_gpu = bs // gpu_n if bs_per_gpu is None else bs_per_gpu
+        assert (bs == bs_per_gpu * gpu_n)
+        return bs, bs_per_gpu
+
+    if istrain:
+        cfgt.batch_size, cfgt.batch_size_per_gpu = \
+            align_batch_size(cfgt.batch_size, cfgt.batch_size_per_gpu)
+        cfgt.dataset_num_workers, cfgt.dataset_num_workers_per_gpu = \
+            align_batch_size(cfgt.dataset_num_workers, cfgt.dataset_num_workers_per_gpu)
+    if haseval:
+        cfgv.batch_size, cfgv.batch_size_per_gpu = \
+            align_batch_size(cfgv.batch_size, cfgv.batch_size_per_gpu)
+        cfgv.dataset_num_workers, cfgv.dataset_num_workers_per_gpu = \
+            align_batch_size(cfgv.dataset_num_workers, cfgv.dataset_num_workers_per_gpu)
+
+    ##################
+    # create log dir #
+    ##################
+
+    if istrain:
+        if not isdebug:
+            sig = cfgt.get('signature', [])
+            version = get_model().get_version(cfgm.type)
+            sig = sig + ['v{}'.format(version), 's{}'.format(cfge.rnd_seed)]
+        else:
+            sig = ['debug']
+
+        log_dir = [
+            cfge.log_root_dir,
+            '{}_{}'.format(cfgm.symbol, cfgt.dataset.symbol),
+            '_'.join([str(cfge.experiment_id)] + sig)
+        ]
+        log_dir = osp.join(*log_dir)
+        log_file = osp.join(log_dir, 'train.log')
+        if not osp.exists(log_file):
+            os.makedirs(osp.dirname(log_file))
+        cfgt.log_dir = log_dir
+        cfgt.log_file = log_file
+
+        if haseval:
+            cfgv.log_dir = log_dir
+            cfgv.log_file = log_file
+    else:
+        model_symbol = cfgm.symbol
+        if cfgv.get('dataset', None) is None:
+            dataset_symbol = 'nodataset'
+        else:
+            dataset_symbol = cfgv.dataset.symbol
+
+        log_dir = osp.join(cfge.log_root_dir, '{}_{}'.format(model_symbol, dataset_symbol))
+        exp_dir = search_experiment_folder(log_dir, cfge.experiment_id)
+        if exp_dir is None:
+            if not isdebug:
+                sig = cfgv.get('signature', []) + ['evalonly']
+            else:
+                sig = ['debug']
+            exp_dir = '_'.join([str(cfge.experiment_id)] + sig)
+
+        eval_subdir = cfgv.get('eval_subdir', None)
+        # override subdir in debug mode (if eval_subdir is set)
+        eval_subdir = 'debug' if (eval_subdir is not None) and isdebug else eval_subdir
+
+        if eval_subdir is not None:
+            log_dir = osp.join(log_dir, exp_dir, eval_subdir)
+        else:
+            log_dir = osp.join(log_dir, exp_dir)
+
+        disable_log_override = cfgv.get('disable_log_override', False)
+        if osp.isdir(log_dir):
+            if disable_log_override:
+                assert False, 'Override an exsited log_dir is disabled at [{}]'.format(log_dir)
+        else:
+            os.makedirs(log_dir)
+
+        log_file = osp.join(log_dir, 'eval.log')
+        cfgv.log_dir = log_dir
+        cfgv.log_file = log_file
+
+    ######################
+    # print and save cfg #
+    ######################
+
+    pprint.pprint(cfg)
+    with open(log_file, 'w') as f:
+        pprint.pprint(cfg, f)
+    with open(osp.join(log_dir, 'config.yaml'), 'w') as f:
+        yaml.dump(edict_2_dict(cfg), f)
+
+    #############
+    # save code #
+    #############
+
+    save_code = False
+    if istrain:
+        save_code = cfgt.get('save_code', False)
+    elif haseval:
+        save_code = cfgv.get('save_code', False)
+
+    if save_code:
+        codedir = osp.join(log_dir, 'code')
+        if osp.exists(codedir):
+            shutil.rmtree(codedir)
+        for d in ['configs', 'lib']:
+            fromcodedir = d
+            tocodedir = osp.join(codedir, d)
+            shutil.copytree(
+                fromcodedir, tocodedir,
+                ignore=shutil.ignore_patterns(
+                    '*__pycache__*', '*build*'))
+        for codei in os.listdir('.'):
+            if osp.splitext(codei)[1] == 'py':
+                shutil.copy(codei, codedir)
+
+    #######################
+    # set matplotlib mode #
+    #######################
+
+    if 'matplotlib_mode' in cfge:
+        try:
+            matplotlib.use(cfge.matplotlib_mode)
+        except:
+            print('Warning: matplotlib mode [{}] failed to be set!'.format(cfge.matplotlib_mode))
+
+    return cfg
+
+
+def edict_2_dict(x):
+    if isinstance(x, dict):
+        xnew = {}
+        for k in x:
+            xnew[k] = edict_2_dict(x[k])
+        return xnew
+    elif isinstance(x, list):
+        xnew = []
+        for i in range(len(x)):
+            xnew.append( edict_2_dict(x[i]) )
+        return xnew
+    else:
+        return x
+
+
+def search_experiment_folder(root, exid):
+    target = None
+    for fi in os.listdir(root):
+        if not osp.isdir(osp.join(root, fi)):
+            continue
+        if int(fi.split('_')[0]) == exid:
+            if target is not None:
+                return None # duplicated
+            elif target is None:
+                target = fi
+    return target

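To make the macro machinery above concrete, here is a minimal usage sketch (not part of the commit) of how `model_cfg_bank` resolves a config entry: `MODEL(...)` placeholders are expanded by `cfg_solvef`, and `super_cfg` chains are merged before `cfg_solve` runs over the result. It assumes the package's dependencies are installed and the snippet is run from the repository root so that configs/model/*.yaml can be found.

# Hypothetical usage sketch of the config bank.
from core.cfg_helper import model_cfg_bank

cfg = model_cfg_bank()('codi')       # loads and resolves configs/model/codi.yaml
# The 'MODEL(clip_frozen)' placeholder has been replaced by the resolved
# clip_frozen entry, which in turn inherited its remaining fields from 'clip'.
print(cfg.args.clip_cfg.type)        # -> 'clip_frozen'
print(cfg.args.unet_config.type)     # -> 'openai_unet_codi'
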
core/cfg_holder.py
ADDED
@@ -0,0 +1,33 @@
+import copy
+
+
+def singleton(class_):
+    instances = {}
+
+    def getinstance(*args, **kwargs):
+        if class_ not in instances:
+            instances[class_] = class_(*args, **kwargs)
+        return instances[class_]
+    return getinstance
+
+##############
+# cfg_holder #
+##############
+
+
+@singleton
+class cfg_unique_holder(object):
+    def __init__(self):
+        self.cfg = None
+        # this is use to track the main codes.
+        self.code = set()
+
+    def save_cfg(self, cfg):
+        self.cfg = copy.deepcopy(cfg)
+
+    def add_code(self, code):
+        """
+        A new main code is reached and
+        its name is added.
+        """
+        self.code.add(code)

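Because of the `singleton` decorator, every call to `cfg_unique_holder()` returns the same object, so a config saved once is visible everywhere in the process. A minimal usage sketch (not part of the commit):

# Hypothetical usage sketch of the process-wide config holder.
from core.cfg_holder import cfg_unique_holder

cfg_unique_holder().save_cfg({'model': 'codi'})       # store a deep-copied config
assert cfg_unique_holder().cfg == {'model': 'codi'}   # same instance everywhere
cfg_unique_holder().add_code('train')                 # record that a main code ran
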
core/common/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (11.3 kB).

core/common/registry.py
ADDED
@@ -0,0 +1,86 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from argparse import Namespace
+from typing import Union
+
+from hydra.core.config_store import ConfigStore
+from omegaconf import DictConfig
+
+REGISTRIES = {}
+
+
+def setup_registry(registry_name: str,
+                   base_class=None,
+                   default=None,
+                   required=False):
+    assert registry_name.startswith('--')
+    registry_name = registry_name[2:].replace('-', '_')
+
+    REGISTRY = {}
+    REGISTRY_CLASS_NAMES = set()
+    DATACLASS_REGISTRY = {}
+
+    # maintain a registry of all registries
+    if registry_name in REGISTRIES:
+        return  # registry already exists
+    REGISTRIES[registry_name] = {
+        'registry': REGISTRY,
+        'default': default,
+        'dataclass_registry': DATACLASS_REGISTRY,
+    }
+
+    def build_x(cfg: Union[DictConfig, str, Namespace], *extra_args,
+                **extra_kwargs):
+
+        assert isinstance(cfg, str)
+        choice = cfg
+        if choice in DATACLASS_REGISTRY:
+            cfg = DATACLASS_REGISTRY[choice]()
+
+        if choice is None:
+            if required:
+                raise ValueError('{} is required!'.format(registry_name))
+            return None
+
+        cls = REGISTRY[choice]
+        if hasattr(cls, 'build_' + registry_name):
+            builder = getattr(cls, 'build_' + registry_name)
+        else:
+            builder = cls
+        return builder(cfg, *extra_args, **extra_kwargs)
+
+    def register_x(name, dataclass=None):
+        def register_x_cls(cls):
+            if name in REGISTRY:
+                raise ValueError('Cannot register duplicate {} ({})'.format(
+                    registry_name, name))
+            if cls.__name__ in REGISTRY_CLASS_NAMES:
+                raise ValueError(
+                    'Cannot register {} with duplicate class name ({})'.format(
+                        registry_name, cls.__name__))
+            if base_class is not None and not issubclass(cls, base_class):
+                raise ValueError('{} must extend {}'.format(
+                    cls.__name__, base_class.__name__))
+
+            cls.__dataclass = dataclass
+            if cls.__dataclass is not None:
+                DATACLASS_REGISTRY[name] = cls.__dataclass
+
+                cs = ConfigStore.instance()
+                node = dataclass()
+                node._name = name
+                cs.store(name=name,
+                         group=registry_name,
+                         node=node,
+                         provider='layoutlmft')
+
+            REGISTRY[name] = cls
+
+            return cls
+
+        return register_x_cls
+
+    return build_x, register_x, REGISTRY, DATACLASS_REGISTRY

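`setup_registry` builds one named registry and returns a builder plus a registration decorator. A minimal usage sketch (not part of the commit; the `ToyEncoder` class and the `'--encoder'` registry name are made up for illustration):

# Hypothetical usage sketch of setup_registry.
from core.common.registry import setup_registry

build_encoder, register_encoder, ENCODER_REGISTRY, _ = setup_registry('--encoder')

@register_encoder('toy')
class ToyEncoder:
    def __init__(self, cfg):
        self.cfg = cfg   # receives the choice string, since no dataclass is registered

enc = build_encoder('toy')           # looks up 'toy' and instantiates ToyEncoder
assert isinstance(enc, ToyEncoder)
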
core/common/utils.py
ADDED
@@ -0,0 +1,412 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import torch
|
3 |
+
from collections import OrderedDict
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
from PIL import Image
|
7 |
+
import torchvision.transforms as T
|
8 |
+
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor
|
9 |
+
from torchvision import transforms as tvtrans
|
10 |
+
|
11 |
+
from decord import VideoReader, cpu, gpu
|
12 |
+
|
13 |
+
|
14 |
+
###############
|
15 |
+
# text helper #
|
16 |
+
###############
|
17 |
+
|
18 |
+
|
19 |
+
def remove_duplicate_word(tx):
|
20 |
+
def combine_words(input, length):
|
21 |
+
combined_inputs = []
|
22 |
+
if len(splitted_input) > 1:
|
23 |
+
for i in range(len(input) - 1):
|
24 |
+
combined_inputs.append(input[i] + " " + last_word_of(splitted_input[i + 1],
|
25 |
+
length)) # add the last word of the right-neighbour (overlapping) sequence (before it has expanded), which is the next word in the original sentence
|
26 |
+
return combined_inputs, length + 1
|
27 |
+
|
28 |
+
def remove_duplicates(input, length):
|
29 |
+
bool_broke = False #this means we didn't find any duplicates here
|
30 |
+
for i in range(len(input) - length):
|
31 |
+
if input[i] == input[i + length]: #found a duplicate piece of sentence!
|
32 |
+
for j in range(0, length): #remove the overlapping sequences in reverse order
|
33 |
+
del input[i + length - j]
|
34 |
+
bool_broke = True
|
35 |
+
break #break the for loop as the loop length does not matches the length of splitted_input anymore as we removed elements
|
36 |
+
if bool_broke:
|
37 |
+
return remove_duplicates(input,
|
38 |
+
length) #if we found a duplicate, look for another duplicate of the same length
|
39 |
+
return input
|
40 |
+
|
41 |
+
def last_word_of(input, length):
|
42 |
+
splitted = input.split(" ")
|
43 |
+
if len(splitted) == 0:
|
44 |
+
return input
|
45 |
+
else:
|
46 |
+
return splitted[length - 1]
|
47 |
+
|
48 |
+
def split_and_puncsplit(text):
|
49 |
+
tx = text.split(" ")
|
50 |
+
txnew = []
|
51 |
+
for txi in tx:
|
52 |
+
txqueue = []
|
53 |
+
while True:
|
54 |
+
if txi[0] in '([{':
|
55 |
+
txqueue.extend([txi[:1], '<puncnext>'])
|
56 |
+
txi = txi[1:]
|
57 |
+
if len(txi) == 0:
|
58 |
+
break
|
59 |
+
else:
|
60 |
+
break
|
61 |
+
txnew += txqueue
|
62 |
+
txstack = []
|
63 |
+
if len(txi) == 0:
|
64 |
+
continue
|
65 |
+
while True:
|
66 |
+
if txi[-1] in '?!.,:;}])':
|
67 |
+
txstack = ['<puncnext>', txi[-1:]] + txstack
|
68 |
+
txi = txi[:-1]
|
69 |
+
if len(txi) == 0:
|
70 |
+
break
|
71 |
+
else:
|
72 |
+
break
|
73 |
+
if len(txi) != 0:
|
74 |
+
txnew += [txi]
|
75 |
+
txnew += txstack
|
76 |
+
return txnew
|
77 |
+
|
78 |
+
if tx == '':
|
79 |
+
return tx
|
80 |
+
|
81 |
+
splitted_input = split_and_puncsplit(tx)
|
82 |
+
word_length = 1
|
83 |
+
intermediate_output = False
|
84 |
+
while len(splitted_input) > 1:
|
85 |
+
splitted_input = remove_duplicates(splitted_input, word_length)
|
86 |
+
if len(splitted_input) > 1:
|
87 |
+
splitted_input, word_length = combine_words(splitted_input, word_length)
|
88 |
+
if intermediate_output:
|
89 |
+
print(splitted_input)
|
90 |
+
print(word_length)
|
91 |
+
output = splitted_input[0]
|
92 |
+
output = output.replace(' <puncnext> ', '')
|
93 |
+
return output
|
94 |
+
|
95 |
+
|
96 |
+
#################
|
97 |
+
# vision helper #
|
98 |
+
#################
|
99 |
+
|
100 |
+
|
101 |
+
def regularize_image(x, image_size=512):
|
102 |
+
BICUBIC = T.InterpolationMode.BICUBIC
|
103 |
+
if isinstance(x, str):
|
104 |
+
x = Image.open(x)
|
105 |
+
size = min(x.size)
|
106 |
+
elif isinstance(x, Image.Image):
|
107 |
+
x = x.convert('RGB')
|
108 |
+
size = min(x.size)
|
109 |
+
elif isinstance(x, np.ndarray):
|
110 |
+
x = Image.fromarray(x).convert('RGB')
|
111 |
+
size = min(x.size)
|
112 |
+
elif isinstance(x, torch.Tensor):
|
113 |
+
# normalize to [0, 1]
|
114 |
+
x = x/255.0
|
115 |
+
size = min(x.size()[1:])
|
116 |
+
else:
|
117 |
+
assert False, 'Unknown image type'
|
118 |
+
|
119 |
+
"""transforms = T.Compose([
|
120 |
+
T.RandomCrop(size),
|
121 |
+
T.Resize(
|
122 |
+
(image_size, image_size),
|
123 |
+
interpolation=BICUBIC,
|
124 |
+
),
|
125 |
+
T.RandomHorizontalFlip(),
|
126 |
+
T.ToTensor(),
|
127 |
+
])
|
128 |
+
x = transforms(x)
|
129 |
+
|
130 |
+
assert (x.shape[1] == image_size) & (x.shape[2] == image_size), \
|
131 |
+
'Wrong image size'
|
132 |
+
"""
|
133 |
+
x = x * 2 - 1
|
134 |
+
return x
|
135 |
+
|
136 |
+
|
137 |
+
def center_crop(img, new_width=None, new_height=None):
|
138 |
+
width = img.shape[2]
|
139 |
+
height = img.shape[1]
|
140 |
+
|
141 |
+
if new_width is None:
|
142 |
+
new_width = min(width, height)
|
143 |
+
|
144 |
+
if new_height is None:
|
145 |
+
new_height = min(width, height)
|
146 |
+
|
147 |
+
left = int(np.ceil((width - new_width) / 2))
|
148 |
+
right = width - int(np.floor((width - new_width) / 2))
|
149 |
+
|
150 |
+
top = int(np.ceil((height - new_height) / 2))
|
151 |
+
bottom = height - int(np.floor((height - new_height) / 2))
|
152 |
+
if len(img.shape) == 3:
|
153 |
+
center_cropped_img = img[:, top:bottom, left:right]
|
154 |
+
else:
|
155 |
+
center_cropped_img = img[:, top:bottom, left:right, ...]
|
156 |
+
|
157 |
+
return center_cropped_img
|
158 |
+
|
159 |
+
|
160 |
+
def _transform(n_px):
|
161 |
+
return Compose([
|
162 |
+
Resize([n_px, n_px], interpolation=T.InterpolationMode.BICUBIC), ])
|
163 |
+
|
164 |
+
|
165 |
+
def regularize_video(video, image_size=256):
|
166 |
+
min_shape = min(video.shape[1:3])
|
167 |
+
video = center_crop(video, min_shape, min_shape)
|
168 |
+
video = torch.from_numpy(video).permute(0, 3, 1, 2)
|
169 |
+
video = _transform(image_size)(video)
|
170 |
+
video = video / 255.0 * 2.0 - 1.0
|
171 |
+
return video.permute(1, 0, 2, 3)
|
172 |
+
|
173 |
+
|
174 |
+
def time_to_indices(video_reader, time):
|
175 |
+
times = video_reader.get_frame_timestamp(range(len(video_reader))).mean(-1)
|
176 |
+
indices = np.searchsorted(times, time)
|
177 |
+
# Use `np.bitwise_or` so it works both with scalars and numpy arrays.
|
178 |
+
return np.where(np.bitwise_or(indices == 0, times[indices] - time <= time - times[indices - 1]), indices,
|
179 |
+
indices - 1)
|
180 |
+
|
181 |
+
|
182 |
+
def load_video(video_path, sample_duration=8.0, num_frames=8):
|
183 |
+
sample_duration = 4.0
|
184 |
+
num_frames = 4
|
185 |
+
|
186 |
+
vr = VideoReader(video_path, ctx=cpu(0))
|
187 |
+
framerate = vr.get_avg_fps()
|
188 |
+
video_frame_len = len(vr)
|
189 |
+
video_len = video_frame_len / framerate
|
190 |
+
sample_duration = min(sample_duration, video_len)
|
191 |
+
|
192 |
+
if video_len > sample_duration:
|
193 |
+
s = random.random() * (video_len - sample_duration)
|
194 |
+
t = s + sample_duration
|
195 |
+
start, end = time_to_indices(vr, [s, t])
|
196 |
+
end = min(video_frame_len - 1, end)
|
197 |
+
start = min(start, end - 1)
|
198 |
+
downsamlp_indices = np.linspace(start, end, num_frames, endpoint=True).astype(int).tolist()
|
199 |
+
else:
|
200 |
+
downsamlp_indices = np.linspace(0, video_frame_len - 1, num_frames, endpoint=True).astype(int).tolist()
|
201 |
+
|
202 |
+
video = vr.get_batch(downsamlp_indices).asnumpy()
|
203 |
+
return video
|
204 |
+
|
205 |
+
|
206 |
+
###############
|
207 |
+
# some helper #
|
208 |
+
###############
|
209 |
+
|
210 |
+
def atomic_save(cfg, net, opt, step, path):
|
211 |
+
if isinstance(net, (torch.nn.DataParallel,
|
212 |
+
torch.nn.parallel.DistributedDataParallel)):
|
213 |
+
netm = net.module
|
214 |
+
else:
|
215 |
+
netm = net
|
216 |
+
sd = netm.state_dict()
|
217 |
+
slimmed_sd = [(ki, vi) for ki, vi in sd.items()
|
218 |
+
if ki.find('first_stage_model') != 0 and ki.find('cond_stage_model') != 0]
|
219 |
+
|
220 |
+
checkpoint = {
|
221 |
+
"config": cfg,
|
222 |
+
"state_dict": OrderedDict(slimmed_sd),
|
223 |
+
"step": step}
|
224 |
+
if opt is not None:
|
225 |
+
checkpoint['optimizer_states'] = opt.state_dict()
|
226 |
+
import io
|
227 |
+
import fsspec
|
228 |
+
bytesbuffer = io.BytesIO()
|
229 |
+
torch.save(checkpoint, bytesbuffer)
|
230 |
+
with fsspec.open(path, "wb") as f:
|
231 |
+
f.write(bytesbuffer.getvalue())
|
232 |
+
|
233 |
+
|
234 |
+
def load_state_dict(net, cfg):
|
235 |
+
pretrained_pth_full = cfg.get('pretrained_pth_full', None)
|
236 |
+
pretrained_ckpt_full = cfg.get('pretrained_ckpt_full', None)
|
237 |
+
pretrained_pth = cfg.get('pretrained_pth', None)
|
238 |
+
pretrained_ckpt = cfg.get('pretrained_ckpt', None)
|
239 |
+
pretrained_pth_dm = cfg.get('pretrained_pth_dm', None)
|
240 |
+
pretrained_pth_ema = cfg.get('pretrained_pth_ema', None)
|
241 |
+
strict_sd = cfg.get('strict_sd', False)
|
242 |
+
errmsg = "Overlapped model state_dict! This is undesired behavior!"
|
243 |
+
|
244 |
+
if pretrained_pth_full is not None or pretrained_ckpt_full is not None:
|
245 |
+
assert (pretrained_pth is None) and \
|
246 |
+
(pretrained_ckpt is None) and \
|
247 |
+
(pretrained_pth_dm is None) and \
|
248 |
+
(pretrained_pth_ema is None), errmsg
|
249 |
+
if pretrained_pth_full is not None:
|
250 |
+
target_file = pretrained_pth_full
|
251 |
+
sd = torch.load(target_file, map_location='cpu')
|
252 |
+
assert pretrained_ckpt is None, errmsg
|
253 |
+
else:
|
254 |
+
target_file = pretrained_ckpt_full
|
255 |
+
sd = torch.load(target_file, map_location='cpu')['state_dict']
|
256 |
+
print('Load full model from [{}] strict [{}].'.format(
|
257 |
+
target_file, strict_sd))
|
258 |
+
net.load_state_dict(sd, strict=strict_sd)
|
259 |
+
|
260 |
+
if pretrained_pth is not None or pretrained_ckpt is not None:
|
261 |
+
assert (pretrained_ckpt_full is None) and \
|
262 |
+
(pretrained_pth_full is None) and \
|
263 |
+
(pretrained_pth_dm is None) and \
|
264 |
+
(pretrained_pth_ema is None), errmsg
|
265 |
+
if pretrained_pth is not None:
|
266 |
+
target_file = pretrained_pth
|
267 |
+
sd = torch.load(target_file, map_location='cpu')
|
268 |
+
assert pretrained_ckpt is None, errmsg
|
269 |
+
else:
|
270 |
+
target_file = pretrained_ckpt
|
271 |
+
sd = torch.load(target_file, map_location='cpu')['state_dict']
|
272 |
+
print('Load model from [{}] strict [{}].'.format(
|
273 |
+
target_file, strict_sd))
|
274 |
+
sd_extra = [(ki, vi) for ki, vi in net.state_dict().items() \
|
275 |
+
if ki.find('first_stage_model') == 0 or ki.find('cond_stage_model') == 0]
|
276 |
+
sd.update(OrderedDict(sd_extra))
|
277 |
+
net.load_state_dict(sd, strict=strict_sd)
|
278 |
+
|
279 |
+
if pretrained_pth_dm is not None:
|
280 |
+
assert (pretrained_ckpt_full is None) and \
|
281 |
+
(pretrained_pth_full is None) and \
|
282 |
+
(pretrained_pth is None) and \
|
283 |
+
(pretrained_ckpt is None), errmsg
|
284 |
+
print('Load diffusion model from [{}] strict [{}].'.format(
|
285 |
+
pretrained_pth_dm, strict_sd))
|
286 |
+
sd = torch.load(pretrained_pth_dm, map_location='cpu')
|
287 |
+
net.model.diffusion_model.load_state_dict(sd, strict=strict_sd)
|
288 |
+
|
289 |
+
if pretrained_pth_ema is not None:
|
290 |
+
assert (pretrained_ckpt_full is None) and \
|
291 |
+
(pretrained_pth_full is None) and \
|
292 |
+
(pretrained_pth is None) and \
|
293 |
+
(pretrained_ckpt is None), errmsg
|
294 |
+
print('Load unet ema model from [{}] strict [{}].'.format(
|
295 |
+
pretrained_pth_ema, strict_sd))
|
296 |
+
sd = torch.load(pretrained_pth_ema, map_location='cpu')
|
297 |
+
net.model_ema.load_state_dict(sd, strict=strict_sd)
|
298 |
+
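load_state_dict above treats the pretrained_* entries as mutually exclusive families: a full checkpoint, a slimmed model checkpoint, a diffusion-UNet-only file, or an EMA-only file. A minimal sketch of a config selecting the full-checkpoint path (a plain dict works with the .get calls above; the file name is hypothetical):

    cfg = {
        'pretrained_ckpt_full': 'checkpoints/full_model.ckpt',  # full model, read from its 'state_dict'
        'pretrained_pth_full': None,
        'pretrained_pth': None,
        'pretrained_ckpt': None,
        'pretrained_pth_dm': None,    # diffusion UNet only
        'pretrained_pth_ema': None,   # EMA weights only
        'strict_sd': False,
    }
    # load_state_dict(net, cfg) then loads the full checkpoint with strict=False;
    # setting a second family at the same time trips the asserts above.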
|
299 |
+
|
300 |
+
def auto_merge_imlist(imlist, max=64):
|
301 |
+
imlist = imlist[0:max]
|
302 |
+
h, w = imlist[0].shape[0:2]
|
303 |
+
num_images = len(imlist)
|
304 |
+
num_row = int(np.sqrt(num_images))
|
305 |
+
num_col = num_images // num_row + 1 if num_images % num_row != 0 else num_images // num_row
|
306 |
+
canvas = np.zeros([num_row * h, num_col * w, 3], dtype=np.uint8)
|
307 |
+
for idx, im in enumerate(imlist):
|
308 |
+
hi = (idx // num_col) * h
|
309 |
+
wi = (idx % num_col) * w
|
310 |
+
canvas[hi:hi + h, wi:wi + w, :] = im
|
311 |
+
return canvas
|
312 |
+
|
313 |
+
|
314 |
+
def latent2im(net, latent):
|
315 |
+
single_input = len(latent.shape) == 3
|
316 |
+
if single_input:
|
317 |
+
latent = latent[None]
|
318 |
+
im = net.decode_image(latent.to(net.device))
|
319 |
+
im = torch.clamp((im + 1.0) / 2.0, min=0.0, max=1.0)
|
320 |
+
im = [tvtrans.ToPILImage()(i) for i in im]
|
321 |
+
if single_input:
|
322 |
+
im = im[0]
|
323 |
+
return im
|
324 |
+
|
325 |
+
|
326 |
+
def im2latent(net, im):
|
327 |
+
single_input = not isinstance(im, list)
|
328 |
+
if single_input:
|
329 |
+
im = [im]
|
330 |
+
im = torch.stack([tvtrans.ToTensor()(i) for i in im], dim=0)
|
331 |
+
im = (im * 2 - 1).to(net.device)
|
332 |
+
z = net.encode_image(im)
|
333 |
+
if single_input:
|
334 |
+
z = z[0]
|
335 |
+
return z
|
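im2latent and latent2im above are inverses up to the autoencoder's reconstruction error: images are mapped to [-1, 1], encoded, and decoded back through net.decode_image. A minimal round-trip sketch, assuming a `net` exposing encode_image/decode_image as used above and an existing PIL image `img`:

    z = im2latent(net, img)        # PIL image -> latent tensor on net.device
    img_rec = latent2im(net, z)    # latent tensor -> PIL image
    # img_rec approximates img up to the autoencoder's reconstruction error.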
336 |
+
|
337 |
+
|
338 |
+
class color_adjust(object):
|
339 |
+
def __init__(self, ref_from, ref_to):
|
340 |
+
x0, m0, std0 = self.get_data_and_stat(ref_from)
|
341 |
+
x1, m1, std1 = self.get_data_and_stat(ref_to)
|
342 |
+
self.ref_from_stat = (m0, std0)
|
343 |
+
self.ref_to_stat = (m1, std1)
|
344 |
+
self.ref_from = self.preprocess(x0).reshape(-1, 3)
|
345 |
+
self.ref_to = x1.reshape(-1, 3)
|
346 |
+
|
347 |
+
def get_data_and_stat(self, x):
|
348 |
+
if isinstance(x, str):
|
349 |
+
x = np.array(PIL.Image.open(x))
|
350 |
+
elif isinstance(x, PIL.Image.Image):
|
351 |
+
x = np.array(x)
|
352 |
+
elif isinstance(x, torch.Tensor):
|
353 |
+
x = torch.clamp(x, min=0.0, max=1.0)
|
354 |
+
x = np.array(tvtrans.ToPILImage()(x))
|
355 |
+
elif isinstance(x, np.ndarray):
|
356 |
+
pass
|
357 |
+
else:
|
358 |
+
raise ValueError
|
359 |
+
x = x.astype(float)
|
360 |
+
m = np.reshape(x, (-1, 3)).mean(0)
|
361 |
+
s = np.reshape(x, (-1, 3)).std(0)
|
362 |
+
return x, m, s
|
363 |
+
|
364 |
+
def preprocess(self, x):
|
365 |
+
m0, s0 = self.ref_from_stat
|
366 |
+
m1, s1 = self.ref_to_stat
|
367 |
+
y = ((x - m0) / s0) * s1 + m1
|
368 |
+
return y
|
369 |
+
|
370 |
+
def __call__(self, xin, keep=0, simple=False):
|
371 |
+
xin, _, _ = self.get_data_and_stat(xin)
|
372 |
+
x = self.preprocess(xin)
|
373 |
+
if simple:
|
374 |
+
y = (x * (1 - keep) + xin * keep)
|
375 |
+
y = np.clip(y, 0, 255).astype(np.uint8)
|
376 |
+
return y
|
377 |
+
|
378 |
+
h, w = x.shape[:2]
|
379 |
+
x = x.reshape(-1, 3)
|
380 |
+
y = []
|
381 |
+
for chi in range(3):
|
382 |
+
yi = self.pdf_transfer_1d(self.ref_from[:, chi], self.ref_to[:, chi], x[:, chi])
|
383 |
+
y.append(yi)
|
384 |
+
|
385 |
+
y = np.stack(y, axis=1)
|
386 |
+
y = y.reshape(h, w, 3)
|
387 |
+
y = (y.astype(float) * (1 - keep) + xin.astype(float) * keep)
|
388 |
+
y = np.clip(y, 0, 255).astype(np.uint8)
|
389 |
+
return y
|
390 |
+
|
391 |
+
def pdf_transfer_1d(self, arr_fo, arr_to, arr_in, n=600):
|
392 |
+
arr = np.concatenate((arr_fo, arr_to))
|
393 |
+
min_v = arr.min() - 1e-6
|
394 |
+
max_v = arr.max() + 1e-6
|
395 |
+
min_vto = arr_to.min() - 1e-6
|
396 |
+
max_vto = arr_to.max() + 1e-6
|
397 |
+
xs = np.array(
|
398 |
+
[min_v + (max_v - min_v) * i / n for i in range(n + 1)])
|
399 |
+
hist_fo, _ = np.histogram(arr_fo, xs)
|
400 |
+
hist_to, _ = np.histogram(arr_to, xs)
|
401 |
+
xs = xs[:-1]
|
402 |
+
# compute probability distribution
|
403 |
+
cum_fo = np.cumsum(hist_fo)
|
404 |
+
cum_to = np.cumsum(hist_to)
|
405 |
+
d_fo = cum_fo / cum_fo[-1]
|
406 |
+
d_to = cum_to / cum_to[-1]
|
407 |
+
# transfer
|
408 |
+
t_d = np.interp(d_fo, d_to, xs)
|
409 |
+
t_d[d_fo <= d_to[0]] = min_vto
|
410 |
+
t_d[d_fo >= d_to[-1]] = max_vto
|
411 |
+
arr_out = np.interp(arr_in, xs, t_d)
|
412 |
+
return arr_out
|
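color_adjust above is per-channel histogram matching: pdf_transfer_1d builds the cumulative distributions of the two reference images and maps the input through them, with `keep` blending the original pixels back in. A minimal usage sketch (file names are hypothetical):

    ca = color_adjust(ref_from='generated.png', ref_to='reference.png')
    out = ca('generated.png', keep=0.3)                       # full per-channel CDF matching
    out_simple = ca('generated.png', keep=0.3, simple=True)   # mean/std matching only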
core/models/__init__.py
ADDED
@@ -0,0 +1,4 @@
+from .common.get_model import get_model
+from .common.get_optimizer import get_optimizer
+from .common.get_scheduler import get_scheduler
+from .common.utils import get_unit
core/models/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (367 Bytes)
core/models/__pycache__/codi.cpython-38.pyc
ADDED
Binary file (7.7 kB)
core/models/__pycache__/codi_2.cpython-38.pyc
ADDED
Binary file (7.12 kB)
core/models/__pycache__/dani_model.cpython-38.pyc
ADDED
Binary file (4.29 kB)
core/models/__pycache__/ema.cpython-38.pyc
ADDED
Binary file (2.99 kB)
core/models/__pycache__/model_module_infer.cpython-38.pyc
ADDED
Binary file (4.31 kB)
core/models/__pycache__/sd.cpython-38.pyc
ADDED
Binary file (9.82 kB)
core/models/codi.py
ADDED
@@ -0,0 +1,227 @@
1 |
+
from typing import Dict, List
|
2 |
+
import os
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import numpy as np
|
8 |
+
import numpy.random as npr
|
9 |
+
import copy
|
10 |
+
from functools import partial
|
11 |
+
from contextlib import contextmanager
|
12 |
+
|
13 |
+
from .common.get_model import get_model, register
|
14 |
+
from .sd import DDPM
|
15 |
+
|
16 |
+
version = '0'
|
17 |
+
symbol = 'codi'
|
18 |
+
|
19 |
+
|
20 |
+
@register('codi', version)
|
21 |
+
class CoDi(DDPM):
|
22 |
+
def __init__(self,
|
23 |
+
audioldm_cfg=None,
|
24 |
+
autokl_cfg=None,
|
25 |
+
optimus_cfg=None,
|
26 |
+
clip_cfg=None,
|
27 |
+
clap_cfg=None,
|
28 |
+
vision_scale_factor=0.1812,
|
29 |
+
text_scale_factor=4.3108,
|
30 |
+
audio_scale_factor=0.9228,
|
31 |
+
scale_by_std=False,
|
32 |
+
*args,
|
33 |
+
**kwargs):
|
34 |
+
super().__init__(*args, **kwargs)
|
35 |
+
|
36 |
+
if audioldm_cfg is not None:
|
37 |
+
self.audioldm = get_model()(audioldm_cfg)
|
38 |
+
|
39 |
+
if autokl_cfg is not None:
|
40 |
+
self.autokl = get_model()(autokl_cfg)
|
41 |
+
|
42 |
+
if optimus_cfg is not None:
|
43 |
+
self.optimus = get_model()(optimus_cfg)
|
44 |
+
|
45 |
+
if clip_cfg is not None:
|
46 |
+
self.clip = get_model()(clip_cfg)
|
47 |
+
|
48 |
+
if clap_cfg is not None:
|
49 |
+
self.clap = get_model()(clap_cfg)
|
50 |
+
|
51 |
+
if not scale_by_std:
|
52 |
+
self.vision_scale_factor = vision_scale_factor
|
53 |
+
self.text_scale_factor = text_scale_factor
|
54 |
+
self.audio_scale_factor = audio_scale_factor
|
55 |
+
else:
|
56 |
+
self.register_buffer("text_scale_factor", torch.tensor(text_scale_factor))
|
57 |
+
self.register_buffer("audio_scale_factor", torch.tensor(audio_scale_factor))
|
58 |
+
self.register_buffer('vision_scale_factor', torch.tensor(vision_scale_factor))
|
59 |
+
|
60 |
+
@property
|
61 |
+
def device(self):
|
62 |
+
return next(self.parameters()).device
|
63 |
+
|
64 |
+
@torch.no_grad()
|
65 |
+
def autokl_encode(self, image):
|
66 |
+
encoder_posterior = self.autokl.encode(image)
|
67 |
+
z = encoder_posterior.sample().to(image.dtype)
|
68 |
+
return self.vision_scale_factor * z
|
69 |
+
|
70 |
+
@torch.no_grad()
|
71 |
+
def autokl_decode(self, z):
|
72 |
+
z = 1. / self.vision_scale_factor * z
|
73 |
+
return self.autokl.decode(z)
|
74 |
+
|
75 |
+
@torch.no_grad()
|
76 |
+
def optimus_encode(self, text):
|
77 |
+
if isinstance(text, List):
|
78 |
+
tokenizer = self.optimus.tokenizer_encoder
|
79 |
+
token = [tokenizer.tokenize(sentence.lower()) for sentence in text]
|
80 |
+
token_id = []
|
81 |
+
for tokeni in token:
|
82 |
+
token_sentence = [tokenizer._convert_token_to_id(i) for i in tokeni]
|
83 |
+
token_sentence = tokenizer.add_special_tokens_single_sentence(token_sentence)
|
84 |
+
token_id.append(torch.LongTensor(token_sentence))
|
85 |
+
token_id = torch._C._nn.pad_sequence(token_id, batch_first=True, padding_value=0.0)[:, :512]
|
86 |
+
else:
|
87 |
+
token_id = text
|
88 |
+
z = self.optimus.encoder(token_id, attention_mask=(token_id > 0))[1]
|
89 |
+
z_mu, z_logvar = self.optimus.encoder.linear(z).chunk(2, -1)
|
90 |
+
return z_mu.squeeze(1) * self.text_scale_factor
|
91 |
+
|
92 |
+
@torch.no_grad()
|
93 |
+
def optimus_decode(self, z, temperature=1.0, max_length=30):
|
94 |
+
z = 1.0 / self.text_scale_factor * z
|
95 |
+
return self.optimus.decode(z, temperature, max_length=max_length)
|
96 |
+
|
97 |
+
@torch.no_grad()
|
98 |
+
def audioldm_encode(self, audio, time=2.0):
|
99 |
+
encoder_posterior = self.audioldm.encode(audio, time=time)
|
100 |
+
z = encoder_posterior.sample().to(audio.dtype)
|
101 |
+
return z * self.audio_scale_factor
|
102 |
+
|
103 |
+
@torch.no_grad()
|
104 |
+
def audioldm_decode(self, z):
|
105 |
+
if torch.max(torch.abs(z)) > 1e2:
|
106 |
+
z = torch.clip(z, min=-10, max=10)
|
107 |
+
z = 1.0 / self.audio_scale_factor * z
|
108 |
+
return self.audioldm.decode(z)
|
109 |
+
|
110 |
+
@torch.no_grad()
|
111 |
+
def mel_spectrogram_to_waveform(self, mel):
|
112 |
+
# Mel: [bs, 1, t-steps, fbins]
|
113 |
+
if len(mel.size()) == 4:
|
114 |
+
mel = mel.squeeze(1)
|
115 |
+
mel = mel.permute(0, 2, 1)
|
116 |
+
waveform = self.audioldm.vocoder(mel)
|
117 |
+
waveform = waveform.cpu().detach().numpy()
|
118 |
+
return waveform
|
119 |
+
|
120 |
+
@torch.no_grad()
|
121 |
+
def clip_encode_text(self, text, encode_type='encode_text'):
|
122 |
+
swap_type = self.clip.encode_type
|
123 |
+
self.clip.encode_type = encode_type
|
124 |
+
embedding = self.clip(text, encode_type)
|
125 |
+
self.clip.encode_type = swap_type
|
126 |
+
return embedding
|
127 |
+
|
128 |
+
@torch.no_grad()
|
129 |
+
def clip_encode_vision(self, vision, encode_type='encode_vision'):
|
130 |
+
swap_type = self.clip.encode_type
|
131 |
+
self.clip.encode_type = encode_type
|
132 |
+
embedding = self.clip(vision, encode_type)
|
133 |
+
self.clip.encode_type = swap_type
|
134 |
+
return embedding
|
135 |
+
|
136 |
+
@torch.no_grad()
|
137 |
+
def clap_encode_audio(self, audio):
|
138 |
+
embedding = self.clap(audio)
|
139 |
+
return embedding
|
140 |
+
|
141 |
+
def forward(self, x=None, c=None, noise=None, xtype='image', ctype='prompt', u=None, return_algined_latents=False):
|
142 |
+
if isinstance(x, list):
|
143 |
+
t = torch.randint(0, self.num_timesteps, (x[0].shape[0],), device=x[0].device).long()
|
144 |
+
else:
|
145 |
+
t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=x.device).long()
|
146 |
+
return self.p_losses(x, c, t, noise, xtype, ctype, u, return_algined_latents)
|
147 |
+
|
148 |
+
def apply_model(self, x_noisy, t, cond, xtype='image', ctype='text', u=None, return_algined_latents=False):
|
149 |
+
return self.model.diffusion_model(x_noisy, t, cond, xtype, ctype, u, return_algined_latents)
|
150 |
+
|
151 |
+
def get_pixel_loss(self, pred, target, mean=True):
|
152 |
+
if self.loss_type == 'l1':
|
153 |
+
loss = (target - pred).abs()
|
154 |
+
if mean:
|
155 |
+
loss = loss.mean()
|
156 |
+
elif self.loss_type == 'l2':
|
157 |
+
if mean:
|
158 |
+
loss = torch.nn.functional.mse_loss(target, pred)
|
159 |
+
else:
|
160 |
+
loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
|
161 |
+
else:
|
162 |
+
raise NotImplementedError("unknown loss type '{loss_type}'")
|
163 |
+
loss = torch.nan_to_num(loss, nan=0.0, posinf=0.0, neginf=-0.0)
|
164 |
+
return loss
|
165 |
+
|
166 |
+
def get_text_loss(self, pred, target):
|
167 |
+
if self.loss_type == 'l1':
|
168 |
+
loss = (target - pred).abs()
|
169 |
+
elif self.loss_type == 'l2':
|
170 |
+
loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
|
171 |
+
loss = torch.nan_to_num(loss, nan=0.0, posinf=0.0, neginf=0.0)
|
172 |
+
return loss
|
173 |
+
|
174 |
+
def p_losses(self, x_start, cond, t, noise=None, xtype='image', ctype='prompt', u=None, return_algined_latents=False):
|
175 |
+
if isinstance(x_start, list):
|
176 |
+
noise = [torch.randn_like(x_start_i) for x_start_i in x_start] if noise is None else noise
|
177 |
+
x_noisy = [self.q_sample(x_start=x_start_i, t=t, noise=noise_i) for x_start_i, noise_i in zip(x_start, noise)]
|
178 |
+
model_output = self.apply_model(x_noisy, t, cond, xtype, ctype, u, return_algined_latents)
|
179 |
+
if return_algined_latents:
|
180 |
+
return model_output
|
181 |
+
|
182 |
+
loss_dict = {}
|
183 |
+
|
184 |
+
if self.parameterization == "x0":
|
185 |
+
target = x_start
|
186 |
+
elif self.parameterization == "eps":
|
187 |
+
target = noise
|
188 |
+
else:
|
189 |
+
raise NotImplementedError()
|
190 |
+
|
191 |
+
loss = 0.0
|
192 |
+
for model_output_i, target_i, xtype_i in zip(model_output, target, xtype):
|
193 |
+
if xtype_i == 'image':
|
194 |
+
loss_simple = self.get_pixel_loss(model_output_i, target_i, mean=False).mean([1, 2, 3])
|
195 |
+
elif xtype_i == 'video':
|
196 |
+
loss_simple = self.get_pixel_loss(model_output_i, target_i, mean=False).mean([1, 2, 3, 4])
|
197 |
+
elif xtype_i == 'text':
|
198 |
+
loss_simple = self.get_text_loss(model_output_i, target_i).mean([1])
|
199 |
+
elif xtype_i == 'audio':
|
200 |
+
loss_simple = self.get_pixel_loss(model_output_i, target_i, mean=False).mean([1, 2, 3])
|
201 |
+
loss += loss_simple.mean()
|
202 |
+
return loss / len(xtype)
|
203 |
+
|
204 |
+
else:
|
205 |
+
noise = torch.randn_like(x_start) if noise is None else noise
|
206 |
+
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
|
207 |
+
model_output = self.apply_model(x_noisy, t, cond, xtype, ctype)
|
208 |
+
|
209 |
+
loss_dict = {}
|
210 |
+
|
211 |
+
if self.parameterization == "x0":
|
212 |
+
target = x_start
|
213 |
+
elif self.parameterization == "eps":
|
214 |
+
target = noise
|
215 |
+
else:
|
216 |
+
raise NotImplementedError()
|
217 |
+
|
218 |
+
if xtype == 'image':
|
219 |
+
loss_simple = self.get_pixel_loss(model_output, target, mean=False).mean([1, 2, 3])
|
220 |
+
elif xtype == 'video':
|
221 |
+
loss_simple = self.get_pixel_loss(model_output, target, mean=False).mean([1, 2, 3, 4])
|
222 |
+
elif xtype == 'text':
|
223 |
+
loss_simple = self.get_text_loss(model_output, target).mean([1])
|
224 |
+
elif xtype == 'audio':
|
225 |
+
loss_simple = self.get_pixel_loss(model_output, target, mean=False).mean([1, 2, 3])
|
226 |
+
loss = loss_simple.mean()
|
227 |
+
return loss
|
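For the single-modality branch of p_losses above, the objective is the standard DDPM regression: q_sample adds noise at a random timestep and the network is trained to predict either that noise ("eps") or the clean input ("x0"). A self-contained sketch of the "eps" case with a toy model and explicit schedule tensors (not the repository's UNet):

    import torch
    import torch.nn.functional as F

    def toy_p_losses(model, x_start, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod):
        b = x_start.shape[0]
        t = torch.randint(0, sqrt_alphas_cumprod.shape[0], (b,), device=x_start.device)
        noise = torch.randn_like(x_start)
        a = sqrt_alphas_cumprod[t].view(b, 1, 1, 1)
        s = sqrt_one_minus_alphas_cumprod[t].view(b, 1, 1, 1)
        x_noisy = a * x_start + s * noise      # q_sample
        pred = model(x_noisy, t)               # network predicts the added noise
        return F.mse_loss(pred, noise, reduction='none').mean([1, 2, 3]).mean()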
core/models/codi_2.py
ADDED
@@ -0,0 +1,221 @@
1 |
+
from typing import Dict, List
|
2 |
+
import os
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import numpy as np
|
8 |
+
import numpy.random as npr
|
9 |
+
import copy
|
10 |
+
from functools import partial
|
11 |
+
from contextlib import contextmanager
|
12 |
+
|
13 |
+
from .common.get_model import get_model, register
|
14 |
+
from .sd import DDPM
|
15 |
+
|
16 |
+
version = '0'
|
17 |
+
symbol = 'thesis_model'
|
18 |
+
|
19 |
+
|
20 |
+
@register('thesis_model', version)
|
21 |
+
class CoDi(DDPM):
|
22 |
+
def __init__(self,
|
23 |
+
autokl_cfg=None,
|
24 |
+
optimus_cfg=None,
|
25 |
+
clip_cfg=None,
|
26 |
+
vision_scale_factor=0.1812,
|
27 |
+
text_scale_factor=4.3108,
|
28 |
+
audio_scale_factor=0.9228,
|
29 |
+
scale_by_std=False,
|
30 |
+
*args,
|
31 |
+
**kwargs):
|
32 |
+
super().__init__(*args, **kwargs)
|
33 |
+
|
34 |
+
if autokl_cfg is not None:
|
35 |
+
self.autokl = get_model()(autokl_cfg)
|
36 |
+
|
37 |
+
if optimus_cfg is not None:
|
38 |
+
self.optimus = get_model()(optimus_cfg)
|
39 |
+
|
40 |
+
if clip_cfg is not None:
|
41 |
+
self.clip = get_model()(clip_cfg)
|
42 |
+
|
43 |
+
if not scale_by_std:
|
44 |
+
self.vision_scale_factor = vision_scale_factor
|
45 |
+
self.text_scale_factor = text_scale_factor
|
46 |
+
self.audio_scale_factor = audio_scale_factor
|
47 |
+
else:
|
48 |
+
self.register_buffer("text_scale_factor", torch.tensor(text_scale_factor))
|
49 |
+
self.register_buffer("audio_scale_factor", torch.tensor(audio_scale_factor))
|
50 |
+
self.register_buffer('vision_scale_factor', torch.tensor(vision_scale_factor))
|
51 |
+
|
52 |
+
@property
|
53 |
+
def device(self):
|
54 |
+
return next(self.parameters()).device
|
55 |
+
|
56 |
+
@torch.no_grad()
|
57 |
+
def autokl_encode(self, image):
|
58 |
+
encoder_posterior = self.autokl.encode(image)
|
59 |
+
z = encoder_posterior.sample().to(image.dtype)
|
60 |
+
return self.vision_scale_factor * z
|
61 |
+
|
62 |
+
@torch.no_grad()
|
63 |
+
def autokl_decode(self, z):
|
64 |
+
z = 1. / self.vision_scale_factor * z
|
65 |
+
return self.autokl.decode(z)
|
66 |
+
|
67 |
+
@torch.no_grad()
|
68 |
+
def optimus_encode(self, text):
|
69 |
+
if isinstance(text, List):
|
70 |
+
tokenizer = self.optimus.tokenizer_encoder
|
71 |
+
token = [tokenizer.tokenize(sentence.lower()) for sentence in text]
|
72 |
+
token_id = []
|
73 |
+
for tokeni in token:
|
74 |
+
token_sentence = [tokenizer._convert_token_to_id(i) for i in tokeni]
|
75 |
+
token_sentence = tokenizer.add_special_tokens_single_sentence(token_sentence)
|
76 |
+
token_id.append(torch.LongTensor(token_sentence))
|
77 |
+
token_id = torch._C._nn.pad_sequence(token_id, batch_first=True, padding_value=0.0)[:, :512]
|
78 |
+
else:
|
79 |
+
token_id = text
|
80 |
+
z = self.optimus.encoder(token_id, attention_mask=(token_id > 0))[1]
|
81 |
+
z_mu, z_logvar = self.optimus.encoder.linear(z).chunk(2, -1)
|
82 |
+
return z_mu.squeeze(1) * self.text_scale_factor
|
83 |
+
|
84 |
+
@torch.no_grad()
|
85 |
+
def optimus_decode(self, z, temperature=1.0):
|
86 |
+
z = 1.0 / self.text_scale_factor * z
|
87 |
+
return self.optimus.decode(z, temperature)
|
88 |
+
|
89 |
+
@torch.no_grad()
|
90 |
+
def clip_encode_text(self, text, encode_type='encode_text'):
|
91 |
+
swap_type = self.clip.encode_type
|
92 |
+
self.clip.encode_type = encode_type
|
93 |
+
embedding = self.clip(text, encode_type)
|
94 |
+
self.clip.encode_type = swap_type
|
95 |
+
return embedding
|
96 |
+
|
97 |
+
@torch.no_grad()
|
98 |
+
def clip_encode_vision(self, vision, encode_type='encode_vision'):
|
99 |
+
swap_type = self.clip.encode_type
|
100 |
+
self.clip.encode_type = encode_type
|
101 |
+
embedding = self.clip(vision, encode_type)
|
102 |
+
self.clip.encode_type = swap_type
|
103 |
+
return embedding
|
104 |
+
|
105 |
+
@torch.no_grad()
|
106 |
+
def clap_encode_audio(self, audio):
|
107 |
+
embedding = self.clap(audio)
|
108 |
+
return embedding
|
109 |
+
|
110 |
+
def forward(self, x=None, c=None, noise=None, xtype='frontal', ctype='text', u=None, return_algined_latents=False, env_enc=False):
|
111 |
+
if isinstance(x, list):
|
112 |
+
t = torch.randint(0, self.num_timesteps, (x[0].shape[0],), device=x[0].device).long()
|
113 |
+
else:
|
114 |
+
t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=x.device).long()
|
115 |
+
return self.p_losses(x, c, t, noise, xtype, ctype, u, return_algined_latents, env_enc)
|
116 |
+
|
117 |
+
def apply_model(self, x_noisy, t, cond, xtype='frontal', ctype='text', u=None, return_algined_latents=False, env_enc=False):
|
118 |
+
return self.model.diffusion_model(x_noisy, t, cond, xtype, ctype, u, return_algined_latents, env_enc=env_enc)
|
119 |
+
|
120 |
+
def get_pixel_loss(self, pred, target, mean=True):
|
121 |
+
if self.loss_type == 'l1':
|
122 |
+
loss = (target - pred).abs()
|
123 |
+
if mean:
|
124 |
+
loss = loss.mean()
|
125 |
+
elif self.loss_type == 'l2':
|
126 |
+
if mean:
|
127 |
+
loss = torch.nn.functional.mse_loss(target, pred)
|
128 |
+
else:
|
129 |
+
loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
|
130 |
+
else:
|
131 |
+
raise NotImplementedError("unknown loss type '{loss_type}'")
|
132 |
+
loss = torch.nan_to_num(loss, nan=0.0, posinf=0.0, neginf=-0.0)
|
133 |
+
return loss
|
134 |
+
|
135 |
+
def get_text_loss(self, pred, target):
|
136 |
+
if self.loss_type == 'l1':
|
137 |
+
loss = (target - pred).abs()
|
138 |
+
elif self.loss_type == 'l2':
|
139 |
+
loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
|
140 |
+
loss = torch.nan_to_num(loss, nan=0.0, posinf=0.0, neginf=0.0)
|
141 |
+
return loss
|
142 |
+
|
143 |
+
def p_losses(self, x_start, cond, t, noise=None, xtype='frontal', ctype='text', u=None,
|
144 |
+
return_algined_latents=False, env_enc=False):
|
145 |
+
if isinstance(x_start, list):
|
146 |
+
noise = [torch.randn_like(x_start_i) for x_start_i in x_start] if noise is None else noise
|
147 |
+
x_noisy = [self.q_sample(x_start=x_start_i, t=t, noise=noise_i) for x_start_i, noise_i in
|
148 |
+
zip(x_start, noise)]
|
149 |
+
if not env_enc:
|
150 |
+
model_output = self.apply_model(x_noisy, t, cond, xtype, ctype, u, return_algined_latents, env_enc)
h_con = None
|
151 |
+
else:
|
152 |
+
model_output, h_con = self.apply_model(x_noisy, t, cond, xtype, ctype, u, return_algined_latents, env_enc)
|
153 |
+
if return_algined_latents:
|
154 |
+
return model_output
|
155 |
+
|
156 |
+
loss_dict = {}
|
157 |
+
|
158 |
+
if self.parameterization == "x0":
|
159 |
+
target = x_start
|
160 |
+
elif self.parameterization == "eps":
|
161 |
+
target = noise
|
162 |
+
else:
|
163 |
+
raise NotImplementedError()
|
164 |
+
|
165 |
+
loss = 0.0
|
166 |
+
for model_output_i, target_i, xtype_i in zip(model_output, target, xtype):
|
167 |
+
if xtype_i == 'frontal':
|
168 |
+
loss_simple = self.get_pixel_loss(model_output_i, target_i, mean=False).mean([1, 2, 3])
|
169 |
+
elif xtype_i == 'text':
|
170 |
+
loss_simple = self.get_text_loss(model_output_i, target_i).mean([1])
|
171 |
+
elif xtype_i == 'lateral':
|
172 |
+
loss_simple = self.get_pixel_loss(model_output_i, target_i, mean=False).mean([1, 2, 3])
|
173 |
+
loss += loss_simple.mean()
|
174 |
+
|
175 |
+
# Check whether the model also returned h_con.
# If so, we have the latent representations of the two modalities
# extracted by the environmental encoders; since they are two tensors of size batch_size x 1 x 1280,
# we can also use them to compute a contrastive loss term (cross-entropy, as in CLIP).
|
179 |
+
if h_con is not None:
|
180 |
+
def similarity(z_a, z_b):
|
181 |
+
return F.cosine_similarity(z_a, z_b)
|
182 |
+
|
183 |
+
z_a, z_b = h_con
|
184 |
+
|
185 |
+
z_a = z_a / z_a.norm(dim=-1, keepdim=True)
|
186 |
+
z_b = z_b / z_b.norm(dim=-1, keepdim=True)
|
187 |
+
|
188 |
+
logits_a = z_a.squeeze() @ z_b.squeeze().t()
|
189 |
+
logits_b = z_b.squeeze() @ z_a.squeeze().t()
|
190 |
+
|
191 |
+
labels = torch.arange(len(z_a)).to(z_a.device)
|
192 |
+
|
193 |
+
loss_a = F.cross_entropy(logits_a, labels)
|
194 |
+
loss_b = F.cross_entropy(logits_b, labels)
|
195 |
+
|
196 |
+
loss_con = (loss_a + loss_b) / 2
|
197 |
+
loss += loss_con
|
198 |
+
return loss / len(xtype)
|
199 |
+
|
200 |
+
else:
|
201 |
+
noise = torch.randn_like(x_start) if noise is None else noise
|
202 |
+
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
|
203 |
+
model_output = self.apply_model(x_noisy, t, cond, xtype, ctype)
|
204 |
+
|
205 |
+
loss_dict = {}
|
206 |
+
|
207 |
+
if self.parameterization == "x0":
|
208 |
+
target = x_start
|
209 |
+
elif self.parameterization == "eps":
|
210 |
+
target = noise
|
211 |
+
else:
|
212 |
+
raise NotImplementedError()
|
213 |
+
|
214 |
+
if xtype == 'frontal':
|
215 |
+
loss_simple = self.get_pixel_loss(model_output, target, mean=False).mean([1, 2, 3])
|
216 |
+
elif xtype == 'text':
|
217 |
+
loss_simple = self.get_text_loss(model_output, target).mean([1])
|
218 |
+
elif xtype == 'lateral':
|
219 |
+
loss_simple = self.get_pixel_loss(model_output, target, mean=False).mean([1, 2, 3])
|
220 |
+
loss = loss_simple.mean()
|
221 |
+
return loss
|
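The h_con branch above is a CLIP-style symmetric cross-entropy between the two environmental-encoder latents. A standalone sketch of that term with explicit shapes; the temperature argument is an addition here, the code above uses a fixed scale of 1:

    import torch
    import torch.nn.functional as F

    def contrastive_term(z_a, z_b, temperature=1.0):
        # z_a, z_b: [batch, 1, 1280] latents from the two environmental encoders.
        z_a = F.normalize(z_a.squeeze(1), dim=-1)
        z_b = F.normalize(z_b.squeeze(1), dim=-1)
        logits = z_a @ z_b.t() / temperature          # [batch, batch] similarity matrix
        labels = torch.arange(z_a.shape[0], device=z_a.device)
        return (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels)) / 2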
core/models/common/__pycache__/get_model.cpython-38.pyc
ADDED
Binary file (2.96 kB)
core/models/common/__pycache__/get_optimizer.cpython-38.pyc
ADDED
Binary file (1.94 kB)
core/models/common/__pycache__/get_scheduler.cpython-38.pyc
ADDED
Binary file (9.55 kB)
core/models/common/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (9.75 kB)
core/models/common/get_model.py
ADDED
@@ -0,0 +1,88 @@
1 |
+
from email.policy import strict
|
2 |
+
import torch
|
3 |
+
import torchvision.models
|
4 |
+
import os.path as osp
|
5 |
+
import copy
|
6 |
+
from .utils import \
|
7 |
+
get_total_param, get_total_param_sum, \
|
8 |
+
get_unit
|
9 |
+
|
10 |
+
|
11 |
+
def singleton(class_):
|
12 |
+
instances = {}
|
13 |
+
|
14 |
+
def getinstance(*args, **kwargs):
|
15 |
+
if class_ not in instances:
|
16 |
+
instances[class_] = class_(*args, **kwargs)
|
17 |
+
return instances[class_]
|
18 |
+
return getinstance
|
19 |
+
|
20 |
+
|
21 |
+
def preprocess_model_args(args):
|
22 |
+
# If args has layer_units, get the corresponding
|
23 |
+
# units.
|
24 |
+
# If args get backbone, get the backbone model.
|
25 |
+
args = copy.deepcopy(args)
|
26 |
+
if 'layer_units' in args:
|
27 |
+
layer_units = [
|
28 |
+
get_unit()(i) for i in args.layer_units
|
29 |
+
]
|
30 |
+
args.layer_units = layer_units
|
31 |
+
if 'backbone' in args:
|
32 |
+
args.backbone = get_model()(args.backbone)
|
33 |
+
return args
|
34 |
+
|
35 |
+
@singleton
|
36 |
+
class get_model(object):
|
37 |
+
def __init__(self):
|
38 |
+
self.model = {}
|
39 |
+
self.version = {}
|
40 |
+
|
41 |
+
def register(self, model, name, version='x'):
|
42 |
+
self.model[name] = model
|
43 |
+
self.version[name] = version
|
44 |
+
|
45 |
+
def __call__(self, cfg, verbose=True):
|
46 |
+
"""
|
47 |
+
Construct model based on the config.
|
48 |
+
"""
|
49 |
+
t = cfg.type
|
50 |
+
|
51 |
+
# the register is in each file
|
52 |
+
if t.find('audioldm')==0:
|
53 |
+
from ..latent_diffusion.vae import audioldm
|
54 |
+
elif t.find('autoencoderkl')==0:
|
55 |
+
from ..latent_diffusion.vae import autokl
|
56 |
+
elif t.find('optimus')==0:
|
57 |
+
from ..latent_diffusion.vae import optimus
|
58 |
+
|
59 |
+
elif t.find('clip')==0:
|
60 |
+
from ..encoders import clip
|
61 |
+
elif t.find('clap')==0:
|
62 |
+
from ..encoders import clap
|
63 |
+
|
64 |
+
elif t.find('sd')==0:
|
65 |
+
from .. import sd
|
66 |
+
elif t.find('codi')==0:
|
67 |
+
from .. import codi
|
68 |
+
elif t.find('thesis_model')==0:
|
69 |
+
from .. import codi_2
|
70 |
+
elif t.find('openai_unet')==0:
|
71 |
+
from ..latent_diffusion import diffusion_unet
|
72 |
+
elif t.find('prova')==0:
|
73 |
+
from ..latent_diffusion import diffusion_unet
|
74 |
+
|
75 |
+
args = preprocess_model_args(cfg.args)
|
76 |
+
net = self.model[t](**args)
|
77 |
+
|
78 |
+
return net
|
79 |
+
|
80 |
+
def get_version(self, name):
|
81 |
+
return self.version[name]
|
82 |
+
|
83 |
+
|
84 |
+
def register(name, version='x'):
|
85 |
+
def wrapper(class_):
|
86 |
+
get_model().register(class_, name, version)
|
87 |
+
return class_
|
88 |
+
return wrapper
|
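get_model above is a singleton registry: @register('name') stores a class under that name and get_model()(cfg) instantiates cfg.type with cfg.args. A minimal sketch with a toy class and an in-code config (the real configs come from the yaml files; the class and values here are hypothetical):

    import torch.nn as nn
    from types import SimpleNamespace

    @register('toy_mlp', version='0')
    class ToyMLP(nn.Module):
        def __init__(self, in_dim, out_dim):
            super().__init__()
            self.net = nn.Linear(in_dim, out_dim)

        def forward(self, x):
            return self.net(x)

    cfg = SimpleNamespace(type='toy_mlp', args={'in_dim': 16, 'out_dim': 8})
    # net = get_model()(cfg)   # looks up 'toy_mlp' and calls ToyMLP(**cfg.args)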
core/models/common/get_optimizer.py
ADDED
@@ -0,0 +1,50 @@
1 |
+
import torch
|
2 |
+
import torch.optim as optim
|
3 |
+
import numpy as np
|
4 |
+
import itertools
|
5 |
+
|
6 |
+
|
7 |
+
def singleton(class_):
|
8 |
+
instances = {}
|
9 |
+
|
10 |
+
def getinstance(*args, **kwargs):
|
11 |
+
if class_ not in instances:
|
12 |
+
instances[class_] = class_(*args, **kwargs)
|
13 |
+
return instances[class_]
|
14 |
+
return getinstance
|
15 |
+
|
16 |
+
|
17 |
+
class get_optimizer(object):
|
18 |
+
def __init__(self):
|
19 |
+
self.optimizer = {}
|
20 |
+
self.register(optim.SGD, 'sgd')
|
21 |
+
self.register(optim.Adam, 'adam')
|
22 |
+
self.register(optim.AdamW, 'adamw')
|
23 |
+
|
24 |
+
def register(self, optim, name):
|
25 |
+
self.optimizer[name] = optim
|
26 |
+
|
27 |
+
def __call__(self, net, cfg):
|
28 |
+
if cfg is None:
|
29 |
+
return None
|
30 |
+
t = cfg.type
|
31 |
+
if isinstance(net, (torch.nn.DataParallel,
|
32 |
+
torch.nn.parallel.DistributedDataParallel)):
|
33 |
+
netm = net.module
|
34 |
+
else:
|
35 |
+
netm = net
|
36 |
+
pg = getattr(netm, 'parameter_group', None)
|
37 |
+
|
38 |
+
if pg is not None:
|
39 |
+
params = []
|
40 |
+
for group_name, module_or_para in pg.items():
|
41 |
+
if not isinstance(module_or_para, list):
|
42 |
+
module_or_para = [module_or_para]
|
43 |
+
|
44 |
+
grouped_params = [mi.parameters() if isinstance(mi, torch.nn.Module) else [mi] for mi in module_or_para]
|
45 |
+
grouped_params = itertools.chain(*grouped_params)
|
46 |
+
pg_dict = {'params': grouped_params, 'name': group_name}
|
47 |
+
params.append(pg_dict)
|
48 |
+
else:
|
49 |
+
params = net.parameters()
|
50 |
+
return self.optimizer[t](params, lr=0, **cfg.args)
|
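get_optimizer above checks the unwrapped network for an optional parameter_group attribute; each entry becomes a named torch.optim parameter group, which is what lets a scheduler scale learning rates per group through pg_lrscale. A minimal sketch of wiring that up (group names and values are hypothetical):

    import torch.nn as nn
    from types import SimpleNamespace

    class TwoPartNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.backbone = nn.Linear(32, 32)
            self.head = nn.Linear(32, 4)
            # consumed by get_optimizer above to build named param groups
            self.parameter_group = {'backbone': self.backbone, 'head': self.head}

    net = TwoPartNet()
    opt_cfg = SimpleNamespace(type='adamw', args={'weight_decay': 0.01})
    opt = get_optimizer()(net, opt_cfg)   # lr starts at 0; a scheduler sets it per step
    # a scheduler instance could then call set_lr(opt, 1e-4, pg_lrscale={'backbone': 0.1, 'head': 1.0})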
core/models/common/get_scheduler.py
ADDED
@@ -0,0 +1,273 @@
1 |
+
import torch
|
2 |
+
import torch.optim as optim
|
3 |
+
import numpy as np
|
4 |
+
import copy
|
5 |
+
from ... import sync
|
6 |
+
from ...cfg_holder import cfg_unique_holder as cfguh
|
7 |
+
|
8 |
+
|
9 |
+
def singleton(class_):
|
10 |
+
instances = {}
|
11 |
+
|
12 |
+
def getinstance(*args, **kwargs):
|
13 |
+
if class_ not in instances:
|
14 |
+
instances[class_] = class_(*args, **kwargs)
|
15 |
+
return instances[class_]
|
16 |
+
return getinstance
|
17 |
+
|
18 |
+
|
19 |
+
@singleton
|
20 |
+
class get_scheduler(object):
|
21 |
+
def __init__(self):
|
22 |
+
self.lr_scheduler = {}
|
23 |
+
|
24 |
+
def register(self, lrsf, name):
|
25 |
+
self.lr_scheduler[name] = lrsf
|
26 |
+
|
27 |
+
def __call__(self, cfg):
|
28 |
+
if cfg is None:
|
29 |
+
return None
|
30 |
+
if isinstance(cfg, list):
|
31 |
+
schedulers = []
|
32 |
+
for ci in cfg:
|
33 |
+
t = ci.type
|
34 |
+
schedulers.append(
|
35 |
+
self.lr_scheduler[t](**ci.args))
|
36 |
+
if len(schedulers) == 0:
|
37 |
+
raise ValueError
|
38 |
+
else:
|
39 |
+
return compose_scheduler(schedulers)
|
40 |
+
t = cfg.type
|
41 |
+
return self.lr_scheduler[t](**cfg.args)
|
42 |
+
|
43 |
+
|
44 |
+
def register(name):
|
45 |
+
def wrapper(class_):
|
46 |
+
get_scheduler().register(class_, name)
|
47 |
+
return class_
|
48 |
+
return wrapper
|
49 |
+
|
50 |
+
|
51 |
+
class template_scheduler(object):
|
52 |
+
def __init__(self, step):
|
53 |
+
self.step = step
|
54 |
+
|
55 |
+
def __getitem__(self, idx):
|
56 |
+
raise ValueError
|
57 |
+
|
58 |
+
def set_lr(self, optim, new_lr, pg_lrscale=None):
|
59 |
+
"""
|
60 |
+
Set Each parameter_groups in optim with new_lr
|
61 |
+
New_lr can be find according to the idx.
|
62 |
+
pg_lrscale tells how to scale each pg.
|
63 |
+
"""
|
64 |
+
# new_lr = self.__getitem__(idx)
|
65 |
+
pg_lrscale = copy.deepcopy(pg_lrscale)
|
66 |
+
for pg in optim.param_groups:
|
67 |
+
if pg_lrscale is None:
|
68 |
+
pg['lr'] = new_lr
|
69 |
+
else:
|
70 |
+
pg['lr'] = new_lr * pg_lrscale.pop(pg['name'])
|
71 |
+
assert (pg_lrscale is None) or (len(pg_lrscale)==0), \
|
72 |
+
"pg_lrscale doesn't match pg"
|
73 |
+
|
74 |
+
@register('constant')
|
75 |
+
class constant_scheduler(template_scheduler):
|
76 |
+
def __init__(self, lr, step):
|
77 |
+
super().__init__(step)
|
78 |
+
self.lr = lr
|
79 |
+
|
80 |
+
def __getitem__(self, idx):
|
81 |
+
if idx >= self.step:
|
82 |
+
raise ValueError
|
83 |
+
return self.lr
|
84 |
+
|
85 |
+
|
86 |
+
@register('poly')
|
87 |
+
class poly_scheduler(template_scheduler):
|
88 |
+
def __init__(self, start_lr, end_lr, power, step):
|
89 |
+
super().__init__(step)
|
90 |
+
self.start_lr = start_lr
|
91 |
+
self.end_lr = end_lr
|
92 |
+
self.power = power
|
93 |
+
|
94 |
+
def __getitem__(self, idx):
|
95 |
+
if idx >= self.step:
|
96 |
+
raise ValueError
|
97 |
+
a, b = self.start_lr, self.end_lr
|
98 |
+
p, n = self.power, self.step
|
99 |
+
return b + (a-b)*((1-idx/n)**p)
|
100 |
+
|
101 |
+
|
102 |
+
@register('linear')
|
103 |
+
class linear_scheduler(template_scheduler):
|
104 |
+
def __init__(self, start_lr, end_lr, step):
|
105 |
+
super().__init__(step)
|
106 |
+
self.start_lr = start_lr
|
107 |
+
self.end_lr = end_lr
|
108 |
+
|
109 |
+
def __getitem__(self, idx):
|
110 |
+
if idx >= self.step:
|
111 |
+
raise ValueError
|
112 |
+
a, b, n = self.start_lr, self.end_lr, self.step
|
113 |
+
return b + (a-b)*(1-idx/n)
|
114 |
+
|
115 |
+
|
116 |
+
@register('multistage')
|
117 |
+
class multistage_scheduler(template_scheduler):
|
118 |
+
def __init__(self, start_lr, milestones, gamma, step):
|
119 |
+
super().__init__(step)
|
120 |
+
self.start_lr = start_lr
|
121 |
+
m = [0] + milestones + [step]
|
122 |
+
lr_iter = start_lr
|
123 |
+
self.lr = []
|
124 |
+
for ms, me in zip(m[0:-1], m[1:]):
|
125 |
+
for _ in range(ms, me):
|
126 |
+
self.lr.append(lr_iter)
|
127 |
+
lr_iter *= gamma
|
128 |
+
|
129 |
+
def __getitem__(self, idx):
|
130 |
+
if idx >= self.step:
|
131 |
+
raise ValueError
|
132 |
+
return self.lr[idx]
|
133 |
+
|
134 |
+
|
135 |
+
class compose_scheduler(template_scheduler):
|
136 |
+
def __init__(self, schedulers):
|
137 |
+
self.schedulers = schedulers
|
138 |
+
self.step = [si.step for si in schedulers]
|
139 |
+
self.step_milestone = []
|
140 |
+
acc = 0
|
141 |
+
for i in self.step:
|
142 |
+
acc += i
|
143 |
+
self.step_milestone.append(acc)
|
144 |
+
self.step = sum(self.step)
|
145 |
+
|
146 |
+
def __getitem__(self, idx):
|
147 |
+
if idx >= self.step:
|
148 |
+
raise ValueError
|
149 |
+
ms = self.step_milestone
|
150 |
+
for i, (mi, mj) in enumerate(zip(ms[:-1], ms[1:])):
|
151 |
+
if mi <= idx < mj:
|
152 |
+
return self.schedulers[i][idx - mi]
|
153 |
+
raise ValueError
|
154 |
+
|
155 |
+
####################
|
156 |
+
# lambda scheduler #
|
157 |
+
####################
|
158 |
+
|
159 |
+
|
160 |
+
class LambdaWarmUpCosineScheduler(template_scheduler):
|
161 |
+
"""
|
162 |
+
note: use with a base_lr of 1.0
|
163 |
+
"""
|
164 |
+
def __init__(self,
|
165 |
+
base_lr,
|
166 |
+
warm_up_steps,
|
167 |
+
lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
|
168 |
+
cfgt = cfguh().cfg.train
|
169 |
+
bs = cfgt.batch_size
|
170 |
+
if 'gradacc_every' not in cfgt:
|
171 |
+
print('Warning: gradacc_every not found in the train config, using 1 as default.')
|
172 |
+
acc = cfgt.get('gradacc_every', 1)
|
173 |
+
self.lr_multi = base_lr * bs * acc
|
174 |
+
self.lr_warm_up_steps = warm_up_steps
|
175 |
+
self.lr_start = lr_start
|
176 |
+
self.lr_min = lr_min
|
177 |
+
self.lr_max = lr_max
|
178 |
+
self.lr_max_decay_steps = max_decay_steps
|
179 |
+
self.last_lr = 0.
|
180 |
+
self.verbosity_interval = verbosity_interval
|
181 |
+
|
182 |
+
def schedule(self, n):
|
183 |
+
if self.verbosity_interval > 0:
|
184 |
+
if n % self.verbosity_interval == 0:
|
185 |
+
print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
|
186 |
+
if n < self.lr_warm_up_steps:
|
187 |
+
lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
|
188 |
+
self.last_lr = lr
|
189 |
+
return lr
|
190 |
+
else:
|
191 |
+
t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
|
192 |
+
t = min(t, 1.0)
|
193 |
+
lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
|
194 |
+
1 + np.cos(t * np.pi))
|
195 |
+
self.last_lr = lr
|
196 |
+
return lr
|
197 |
+
|
198 |
+
def __getitem__(self, idx):
|
199 |
+
return self.schedule(idx) * self.lr_multi
|
200 |
+
|
201 |
+
|
202 |
+
class LambdaWarmUpCosineScheduler2(template_scheduler):
|
203 |
+
"""
|
204 |
+
supports repeated iterations, configurable via lists
|
205 |
+
note: use with a base_lr of 1.0.
|
206 |
+
"""
|
207 |
+
def __init__(self,
|
208 |
+
base_lr,
|
209 |
+
warm_up_steps,
|
210 |
+
f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
|
211 |
+
cfgt = cfguh().cfg.train
|
212 |
+
# bs = cfgt.batch_size
|
213 |
+
# if 'gradacc_every' not in cfgt:
|
214 |
+
# print('Warning, gradacc_every is not found in xml, use 1 as default.')
|
215 |
+
# acc = cfgt.get('gradacc_every', 1)
|
216 |
+
# self.lr_multi = base_lr * bs * acc
|
217 |
+
self.lr_multi = base_lr
|
218 |
+
assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
|
219 |
+
self.lr_warm_up_steps = warm_up_steps
|
220 |
+
self.f_start = f_start
|
221 |
+
self.f_min = f_min
|
222 |
+
self.f_max = f_max
|
223 |
+
self.cycle_lengths = cycle_lengths
|
224 |
+
self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
|
225 |
+
self.last_f = 0.
|
226 |
+
self.verbosity_interval = verbosity_interval
|
227 |
+
|
228 |
+
def find_in_interval(self, n):
|
229 |
+
interval = 0
|
230 |
+
for cl in self.cum_cycles[1:]:
|
231 |
+
if n <= cl:
|
232 |
+
return interval
|
233 |
+
interval += 1
|
234 |
+
|
235 |
+
def schedule(self, n):
|
236 |
+
cycle = self.find_in_interval(n)
|
237 |
+
n = n - self.cum_cycles[cycle]
|
238 |
+
if self.verbosity_interval > 0:
|
239 |
+
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
240 |
+
f"current cycle {cycle}")
|
241 |
+
if n < self.lr_warm_up_steps[cycle]:
|
242 |
+
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
|
243 |
+
self.last_f = f
|
244 |
+
return f
|
245 |
+
else:
|
246 |
+
t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
|
247 |
+
t = min(t, 1.0)
|
248 |
+
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
|
249 |
+
1 + np.cos(t * np.pi))
|
250 |
+
self.last_f = f
|
251 |
+
return f
|
252 |
+
|
253 |
+
def __getitem__(self, idx):
|
254 |
+
return self.schedule(idx) * self.lr_multi
|
255 |
+
|
256 |
+
|
257 |
+
@register('stable_diffusion_linear')
|
258 |
+
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
|
259 |
+
def schedule(self, n):
|
260 |
+
cycle = self.find_in_interval(n)
|
261 |
+
n = n - self.cum_cycles[cycle]
|
262 |
+
if self.verbosity_interval > 0:
|
263 |
+
if n % self.verbosity_interval == 0:
|
264 |
+
print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
265 |
+
f"current cycle {cycle}")
|
266 |
+
if n < self.lr_warm_up_steps[cycle]:
|
267 |
+
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
|
268 |
+
self.last_f = f
|
269 |
+
return f
|
270 |
+
else:
|
271 |
+
f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
|
272 |
+
self.last_f = f
|
273 |
+
return f
|
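LambdaWarmUpCosineScheduler above ramps the multiplier linearly from lr_start to lr_max over warm_up_steps and then follows half a cosine down to lr_min at max_decay_steps. The same formula, stripped of the config holder and batch-size multiplier, as a standalone sketch:

    import numpy as np

    def warmup_cosine(n, warm_up_steps, lr_start, lr_max, lr_min, max_decay_steps):
        if n < warm_up_steps:
            return (lr_max - lr_start) / warm_up_steps * n + lr_start
        t = min((n - warm_up_steps) / (max_decay_steps - warm_up_steps), 1.0)
        return lr_min + 0.5 * (lr_max - lr_min) * (1 + np.cos(t * np.pi))

    # warmup_cosine(0, 1000, 1e-6, 1e-4, 1e-6, 10000)    -> 1e-6 (start of warm-up)
    # warmup_cosine(1000, 1000, 1e-6, 1e-4, 1e-6, 10000)  -> 1e-4 (peak, start of cosine decay)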
core/models/common/utils.py
ADDED
@@ -0,0 +1,310 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
import functools
|
6 |
+
import itertools
|
7 |
+
|
8 |
+
|
9 |
+
########
|
10 |
+
# unit #
|
11 |
+
########
|
12 |
+
|
13 |
+
|
14 |
+
def singleton(class_):
|
15 |
+
instances = {}
|
16 |
+
|
17 |
+
def getinstance(*args, **kwargs):
|
18 |
+
if class_ not in instances:
|
19 |
+
instances[class_] = class_(*args, **kwargs)
|
20 |
+
return instances[class_]
|
21 |
+
|
22 |
+
return getinstance
|
23 |
+
|
24 |
+
|
25 |
+
def str2value(v):
|
26 |
+
v = v.strip()
|
27 |
+
try:
|
28 |
+
return int(v)
|
29 |
+
except:
|
30 |
+
pass
|
31 |
+
try:
|
32 |
+
return float(v)
|
33 |
+
except:
|
34 |
+
pass
|
35 |
+
if v in ('True', 'true'):
|
36 |
+
return True
|
37 |
+
elif v in ('False', 'false'):
|
38 |
+
return False
|
39 |
+
else:
|
40 |
+
return v
|
41 |
+
|
42 |
+
|
43 |
+
@singleton
|
44 |
+
class get_unit(object):
|
45 |
+
def __init__(self):
|
46 |
+
self.unit = {}
|
47 |
+
self.register('none', None)
|
48 |
+
|
49 |
+
# general convolution
|
50 |
+
self.register('conv', nn.Conv2d)
|
51 |
+
self.register('bn', nn.BatchNorm2d)
|
52 |
+
self.register('relu', nn.ReLU)
|
53 |
+
self.register('relu6', nn.ReLU6)
|
54 |
+
self.register('lrelu', nn.LeakyReLU)
|
55 |
+
self.register('dropout', nn.Dropout)
|
56 |
+
self.register('dropout2d', nn.Dropout2d)
|
57 |
+
self.register('sine', Sine)
|
58 |
+
self.register('relusine', ReLUSine)
|
59 |
+
|
60 |
+
def register(self,
|
61 |
+
name,
|
62 |
+
unitf, ):
|
63 |
+
|
64 |
+
self.unit[name] = unitf
|
65 |
+
|
66 |
+
def __call__(self, name):
|
67 |
+
if name is None:
|
68 |
+
return None
|
69 |
+
i = name.find('(')
|
70 |
+
i = len(name) if i == -1 else i
|
71 |
+
t = name[:i]
|
72 |
+
f = self.unit[t]
|
73 |
+
args = name[i:].strip('()')
|
74 |
+
if len(args) == 0:
|
75 |
+
args = {}
|
76 |
+
return f
|
77 |
+
else:
|
78 |
+
args = args.split('=')
|
79 |
+
args = [[','.join(i.split(',')[:-1]), i.split(',')[-1]] for i in args]
|
80 |
+
args = list(itertools.chain.from_iterable(args))
|
81 |
+
args = [i.strip() for i in args if len(i) > 0]
|
82 |
+
kwargs = {}
|
83 |
+
for k, v in zip(args[::2], args[1::2]):
|
84 |
+
if v[0] == '(' and v[-1] == ')':
|
85 |
+
kwargs[k] = tuple([str2value(i) for i in v.strip('()').split(',')])
|
86 |
+
elif v[0] == '[' and v[-1] == ']':
|
87 |
+
kwargs[k] = [str2value(i) for i in v.strip('[]').split(',')]
|
88 |
+
else:
|
89 |
+
kwargs[k] = str2value(v)
|
90 |
+
return functools.partial(f, **kwargs)
|
91 |
+
|
92 |
+
|
93 |
+
def register(name):
|
94 |
+
def wrapper(class_):
|
95 |
+
get_unit().register(name, class_)
|
96 |
+
return class_
|
97 |
+
|
98 |
+
return wrapper
|
99 |
+
|
100 |
+
|
101 |
+
class Sine(object):
|
102 |
+
def __init__(self, freq, gain=1):
|
103 |
+
self.freq = freq
|
104 |
+
self.gain = gain
|
105 |
+
self.repr = 'sine(freq={}, gain={})'.format(freq, gain)
|
106 |
+
|
107 |
+
def __call__(self, x, gain=1):
|
108 |
+
act_gain = self.gain * gain
|
109 |
+
return torch.sin(self.freq * x) * act_gain
|
110 |
+
|
111 |
+
def __repr__(self, ):
|
112 |
+
return self.repr
|
113 |
+
|
114 |
+
|
115 |
+
class ReLUSine(nn.Module):
|
116 |
+
def __init__(self):
|
117 |
+
super().__init__()
|
118 |
+
|
119 |
+
def forward(self, input):
|
120 |
+
a = torch.sin(30 * input)
|
121 |
+
b = nn.ReLU(inplace=False)(input)
|
122 |
+
return a + b
|
123 |
+
|
124 |
+
|
125 |
+
@register('lrelu_agc')
|
126 |
+
class lrelu_agc(object):
|
127 |
+
"""
|
128 |
+
The lrelu layer with alpha, gain and clamp
|
129 |
+
"""
|
130 |
+
|
131 |
+
def __init__(self, alpha=0.1, gain=1, clamp=None):
|
132 |
+
# super().__init__()
|
133 |
+
self.alpha = alpha
|
134 |
+
if gain == 'sqrt_2':
|
135 |
+
self.gain = np.sqrt(2)
|
136 |
+
else:
|
137 |
+
self.gain = gain
|
138 |
+
self.clamp = clamp
|
139 |
+
self.repr = 'lrelu_agc(alpha={}, gain={}, clamp={})'.format(
|
140 |
+
alpha, gain, clamp)
|
141 |
+
|
142 |
+
# def forward(self, x, gain=1):
|
143 |
+
def __call__(self, x, gain=1):
|
144 |
+
x = F.leaky_relu(x, negative_slope=self.alpha, inplace=True)
|
145 |
+
act_gain = self.gain * gain
|
146 |
+
act_clamp = self.clamp * gain if self.clamp is not None else None
|
147 |
+
if act_gain != 1:
|
148 |
+
x = x * act_gain
|
149 |
+
if act_clamp is not None:
|
150 |
+
x = x.clamp(-act_clamp, act_clamp)
|
151 |
+
return x
|
152 |
+
|
153 |
+
def __repr__(self, ):
|
154 |
+
return self.repr
|
155 |
+
|
156 |
+
|
157 |
+
####################
|
158 |
+
# spatial encoding #
|
159 |
+
####################
|
160 |
+
|
161 |
+
|
162 |
+
@register('se')
|
163 |
+
class SpatialEncoding(nn.Module):
|
164 |
+
def __init__(self,
|
165 |
+
in_dim,
|
166 |
+
out_dim,
|
167 |
+
sigma=6,
|
168 |
+
cat_input=True,
|
169 |
+
require_grad=False, ):
|
170 |
+
|
171 |
+
super().__init__()
|
172 |
+
assert out_dim % (2 * in_dim) == 0, "out_dim must be divisible by 2 * in_dim"
|
173 |
+
|
174 |
+
n = out_dim // 2 // in_dim
|
175 |
+
m = 2 ** np.linspace(0, sigma, n)
|
176 |
+
m = np.stack([m] + [np.zeros_like(m)] * (in_dim - 1), axis=-1)
|
177 |
+
m = np.concatenate([np.roll(m, i, axis=-1) for i in range(in_dim)], axis=0)
|
178 |
+
self.emb = torch.FloatTensor(m)
|
179 |
+
if require_grad:
|
180 |
+
self.emb = nn.Parameter(self.emb, requires_grad=True)
|
181 |
+
self.in_dim = in_dim
|
182 |
+
self.out_dim = out_dim
|
183 |
+
self.sigma = sigma
|
184 |
+
self.cat_input = cat_input
|
185 |
+
self.require_grad = require_grad
|
186 |
+
|
187 |
+
def forward(self, x, format='[n x c]'):
|
188 |
+
"""
|
189 |
+
Args:
|
190 |
+
x: [n x m1],
|
191 |
+
m1 usually is 2
|
192 |
+
Outputs:
|
193 |
+
y: [n x m2]
|
194 |
+
m2 is the output dimension
|
195 |
+
:param format:
|
196 |
+
"""
|
197 |
+
if format == '[bs x c x 2D]':
|
198 |
+
xshape = x.shape
|
199 |
+
x = x.permute(0, 2, 3, 1).contiguous()
|
200 |
+
x = x.view(-1, x.size(-1))
|
201 |
+
elif format == '[n x c]':
|
202 |
+
pass
|
203 |
+
else:
|
204 |
+
raise ValueError
|
205 |
+
|
206 |
+
if not self.require_grad:
|
207 |
+
self.emb = self.emb.to(x.device)
|
208 |
+
y = torch.mm(x, self.emb.T)
|
209 |
+
if self.cat_input:
|
210 |
+
z = torch.cat([x, torch.sin(y), torch.cos(y)], dim=-1)
|
211 |
+
else:
|
212 |
+
z = torch.cat([torch.sin(y), torch.cos(y)], dim=-1)
|
213 |
+
|
214 |
+
if format == '[bs x c x 2D]':
|
215 |
+
z = z.view(xshape[0], xshape[2], xshape[3], -1)
|
216 |
+
z = z.permute(0, 3, 1, 2).contiguous()
|
217 |
+
return z
|
218 |
+
|
219 |
+
def extra_repr(self):
|
220 |
+
outstr = 'SpatialEncoding (in={}, out={}, sigma={}, cat_input={}, require_grad={})'.format(
|
221 |
+
self.in_dim, self.out_dim, self.sigma, self.cat_input, self.require_grad)
|
222 |
+
return outstr
|
223 |
+
|
224 |
+
|
225 |
+
@register('rffe')
|
226 |
+
class RFFEncoding(SpatialEncoding):
|
227 |
+
"""
|
228 |
+
Random Fourier Features
|
229 |
+
"""
|
230 |
+
|
231 |
+
def __init__(self,
|
232 |
+
in_dim,
|
233 |
+
out_dim,
|
234 |
+
sigma=6,
|
235 |
+
cat_input=True,
|
236 |
+
require_grad=False, ):
|
237 |
+
super().__init__(in_dim, out_dim, sigma, cat_input, require_grad)
|
238 |
+
n = out_dim // 2
|
239 |
+
m = np.random.normal(0, sigma, size=(n, in_dim))
|
240 |
+
self.emb = torch.FloatTensor(m)
|
241 |
+
if require_grad:
|
242 |
+
self.emb = nn.Parameter(self.emb, requires_grad=True)
|
243 |
+
|
244 |
+
def extra_repr(self):
|
245 |
+
outstr = 'RFFEncoding (in={}, out={}, sigma={}, cat_input={}, require_grad={})'.format(
|
246 |
+
self.in_dim, self.out_dim, self.sigma, self.cat_input, self.require_grad)
|
247 |
+
return outstr
|
248 |
+
|
249 |
+
|
250 |
+
##########
|
251 |
+
# helper #
|
252 |
+
##########
|
253 |
+
|
254 |
+
|
255 |
+
def freeze(net):
|
256 |
+
for m in net.modules():
|
257 |
+
if isinstance(m, (
|
258 |
+
nn.BatchNorm2d,
|
259 |
+
nn.SyncBatchNorm,)):
|
260 |
+
# inplace_abn not supported
|
261 |
+
m.eval()
|
262 |
+
for pi in net.parameters():
|
263 |
+
pi.requires_grad = False
|
264 |
+
return net
|
265 |
+
|
266 |
+
|
267 |
+
def common_init(m):
|
268 |
+
if isinstance(m, (
|
269 |
+
nn.Conv2d,
|
270 |
+
nn.ConvTranspose2d,)):
|
271 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
272 |
+
if m.bias is not None:
|
273 |
+
nn.init.constant_(m.bias, 0)
|
274 |
+
elif isinstance(m, (
|
275 |
+
nn.BatchNorm2d,
|
276 |
+
nn.SyncBatchNorm,)):
|
277 |
+
nn.init.constant_(m.weight, 1)
|
278 |
+
nn.init.constant_(m.bias, 0)
|
279 |
+
else:
|
280 |
+
pass
|
281 |
+
|
282 |
+
|
283 |
+
def init_module(module):
|
284 |
+
"""
|
285 |
+
Args:
|
286 |
+
module: [nn.module] list or nn.module
|
287 |
+
a list of module to be initialized.
|
288 |
+
"""
|
289 |
+
if isinstance(module, (list, tuple)):
|
290 |
+
module = list(module)
|
291 |
+
else:
|
292 |
+
module = [module]
|
293 |
+
|
294 |
+
for mi in module:
|
295 |
+
for mii in mi.modules():
|
296 |
+
common_init(mii)
|
297 |
+
|
298 |
+
|
299 |
+
def get_total_param(net):
|
300 |
+
if getattr(net, 'parameters', None) is None:
|
301 |
+
return 0
|
302 |
+
return sum(p.numel() for p in net.parameters())
|
303 |
+
|
304 |
+
|
305 |
+
def get_total_param_sum(net):
|
306 |
+
if getattr(net, 'parameters', None) is None:
|
307 |
+
return 0
|
308 |
+
with torch.no_grad():
|
309 |
+
s = sum(p.cpu().detach().numpy().sum().item() for p in net.parameters())
|
310 |
+
return s
|
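get_unit above parses unit spec strings of the form name(k=v, ...) into layer factories via functools.partial, which is how string-valued config entries turn into modules. A minimal sketch of resolving a few specs (argument values are illustrative):

    relu_cls = get_unit()('relu')                                        # -> nn.ReLU (no args, class returned as-is)
    act = get_unit()('lrelu_agc(alpha=0.2, gain=sqrt_2, clamp=256)')()   # partial -> lrelu_agc instance
    conv_factory = get_unit()('conv(kernel_size=3, padding=1)')          # partial around nn.Conv2d
    layer = conv_factory(64, 64)                                         # nn.Conv2d(64, 64, kernel_size=3, padding=1)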
core/models/dani_model.py
ADDED
@@ -0,0 +1,170 @@
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import torchvision.transforms as tvtrans
|
7 |
+
|
8 |
+
from einops import rearrange
|
9 |
+
|
10 |
+
import pytorch_lightning as pl
|
11 |
+
|
12 |
+
from . import get_model
|
13 |
+
from ..cfg_helper import model_cfg_bank
|
14 |
+
from ..common.utils import regularize_image, regularize_video, remove_duplicate_word
|
15 |
+
|
16 |
+
import warnings
|
17 |
+
|
18 |
+
warnings.filterwarnings("ignore")
|
19 |
+
|
20 |
+
|
21 |
+
class dani_model(pl.LightningModule):
|
22 |
+
def __init__(self, model='thesis_model', load_weights=True, data_dir='pretrained', pth=["CoDi_encoders.pth"], fp16=False):
|
23 |
+
super().__init__()
|
24 |
+
# import torch
|
25 |
+
# device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
26 |
+
cfgm = model_cfg_bank()(model)
|
27 |
+
net = get_model()(cfgm)
|
28 |
+
if load_weights:
|
29 |
+
for path in pth:
|
30 |
+
net.load_state_dict(torch.load(os.path.join(data_dir, path), map_location='cpu'), strict=False)
|
31 |
+
print('Loaded pretrained weights from {}'.format(pth))
|
32 |
+
|
33 |
+
self.net = net
|
34 |
+
|
35 |
+
from core.models.ddim.ddim_vd import DDIMSampler_VD
|
36 |
+
self.sampler = DDIMSampler_VD(net)
|
37 |
+
|
38 |
+
def decode(self, z, xtype):
|
39 |
+
device = z.device
|
40 |
+
net = self.net
|
41 |
+
z = z.to(device)
|
42 |
+
if xtype == 'image':
|
43 |
+
x = net.autokl_decode(z)
|
44 |
+
x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0)
|
45 |
+
return x
|
46 |
+
|
47 |
+
elif xtype == 'video':
|
48 |
+
num_frames = z.shape[2]
|
49 |
+
z = rearrange(z, 'b c f h w -> (b f) c h w')
|
50 |
+
x = net.autokl_decode(z)
|
51 |
+
x = rearrange(x, '(b f) c h w -> b f c h w', f=num_frames)
|
52 |
+
|
53 |
+
x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0)
|
54 |
+
video_list = []
|
55 |
+
for video in x:
|
56 |
+
video_list.append([tvtrans.ToPILImage()(xi) for xi in video])
|
57 |
+
return video_list
|
58 |
+
|
59 |
+
elif xtype == 'text':
|
60 |
+
prompt_temperature = 1.0
|
61 |
+
prompt_merge_same_adj_word = True
|
62 |
+
x = net.optimus_decode(z, temperature=prompt_temperature)
|
63 |
+
"""
|
64 |
+
if prompt_merge_same_adj_word:
|
65 |
+
xnew = []
|
66 |
+
for xi in x:
|
67 |
+
xi_split = xi.split()
|
68 |
+
xinew = []
|
69 |
+
for idxi, wi in enumerate(xi_split):
|
70 |
+
if idxi!=0 and wi==xi_split[idxi-1]:
|
71 |
+
continue
|
72 |
+
xinew.append(wi)
|
73 |
+
xnew.append(remove_duplicate_word(' '.join(xinew)))
|
74 |
+
x = xnew
|
75 |
+
"""
|
76 |
+
return x
|
77 |
+
|
78 |
+
elif xtype == 'audio':
|
79 |
+
x = net.audioldm_decode(z)
|
80 |
+
x = net.mel_spectrogram_to_waveform(x)
|
81 |
+
return x
|
82 |
+
|
83 |
+
def forward(self, xtype=[], condition=[], condition_types=[], n_samples=1,
|
84 |
+
mix_weight={'video': 1, 'audio': 1, 'text': 1, 'image': 1}, image_size=256, ddim_steps=50, scale=7.5,
|
85 |
+
num_frames=8):
|
86 |
+
# import torch
|
87 |
+
# device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
88 |
+
device = self.device
|
89 |
+
net = self.net
|
90 |
+
sampler = self.sampler
|
91 |
+
ddim_eta = 0.0
|
92 |
+
|
93 |
+
conditioning = []
|
94 |
+
assert len(set(condition_types)) == len(condition_types), "we don't support condition with same modalities yet."
|
95 |
+
assert len(condition) == len(condition_types)
|
96 |
+
|
97 |
+
for i, condition_type in enumerate(condition_types):
|
98 |
+
if condition_type == 'image':
|
99 |
+
print(condition[i].shape)
|
100 |
+
ctemp1 = regularize_image(condition[i]).squeeze().to(device)
|
101 |
+
print(ctemp1.shape)
|
102 |
+
ctemp1 = ctemp1[None].repeat(n_samples, 1, 1, 1)
|
103 |
+
cim = net.clip_encode_vision(ctemp1).to(device)
|
104 |
+
uim = None
|
105 |
+
if scale != 1.0:
|
106 |
+
dummy = torch.zeros_like(ctemp1).to(device)
|
107 |
+
uim = net.clip_encode_vision(dummy).to(device)
|
108 |
+
conditioning.append(torch.cat([uim, cim]))
|
109 |
+
|
110 |
+
elif condition_type == 'video':
|
111 |
+
ctemp1 = regularize_video(condition[i]).to(device)
|
112 |
+
ctemp1 = ctemp1[None].repeat(n_samples, 1, 1, 1, 1)
|
113 |
+
cim = net.clip_encode_vision(ctemp1).to(device)
|
114 |
+
uim = None
|
115 |
+
if scale != 1.0:
|
116 |
+
dummy = torch.zeros_like(ctemp1).to(device)
|
117 |
+
uim = net.clip_encode_vision(dummy).to(device)
|
118 |
+
conditioning.append(torch.cat([uim, cim]))
|
119 |
+
|
120 |
+
elif condition_type == 'audio':
|
121 |
+
ctemp = condition[i][None].repeat(n_samples, 1, 1)
|
122 |
+
cad = net.clap_encode_audio(ctemp)
|
123 |
+
uad = None
|
124 |
+
if scale != 1.0:
|
125 |
+
dummy = torch.zeros_like(ctemp)
|
126 |
+
uad = net.clap_encode_audio(dummy)
|
127 |
+
conditioning.append(torch.cat([uad, cad]))
|
128 |
+
|
129 |
+
elif condition_type == 'text':
|
130 |
+
ctx = net.clip_encode_text(n_samples * [condition[i]]).to(device)
|
131 |
+
utx = None
|
132 |
+
if scale != 1.0:
|
133 |
+
utx = net.clip_encode_text(n_samples * [""]).to(device)
|
134 |
+
conditioning.append(torch.cat([utx, ctx]))
|
135 |
+
|
136 |
+
shapes = []
|
137 |
+
for xtype_i in xtype:
|
138 |
+
if xtype_i == 'image':
|
139 |
+
h, w = [image_size, image_size]
|
140 |
+
shape = [n_samples, 4, h // 8, w // 8]
|
141 |
+
elif xtype_i == 'video':
|
142 |
+
h, w = [image_size, image_size]
|
143 |
+
shape = [n_samples, 4, num_frames, h // 8, w // 8]
|
144 |
+
elif xtype_i == 'text':
|
145 |
+
n = 768
|
146 |
+
shape = [n_samples, n]
|
147 |
+
elif xtype_i == 'audio':
|
148 |
+
h, w = [256, 16]
|
149 |
+
shape = [n_samples, 8, h, w]
|
150 |
+
else:
|
151 |
+
raise
|
152 |
+
shapes.append(shape)
|
153 |
+
|
154 |
+
z, _ = sampler.sample(
|
155 |
+
steps=ddim_steps,
|
156 |
+
shape=shapes,
|
157 |
+
condition=conditioning,
|
158 |
+
unconditional_guidance_scale=scale,
|
159 |
+
xtype=xtype,
|
160 |
+
condition_types=condition_types,
|
161 |
+
eta=ddim_eta,
|
162 |
+
verbose=False,
|
163 |
+
mix_weight=mix_weight)
|
164 |
+
|
165 |
+
out_all = []
|
166 |
+
for i, xtype_i in enumerate(xtype):
|
167 |
+
z[i] = z[i].to(device)
|
168 |
+
x_i = self.decode(z[i], xtype_i)
|
169 |
+
out_all.append(x_i)
|
170 |
+
return out_all
|
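An illustrative usage sketch (not part of the uploaded file): a text-to-image call through the wrapper above, assuming the checkpoint really is available as pretrained/CoDi_encoders.pth; the prompt string and device choice are placeholders only.

# Hypothetical inference sketch for dani_model; paths and prompt are illustrative.
import torch
from core.models.dani_model import dani_model

model = dani_model(model='thesis_model', load_weights=True, data_dir='pretrained')
model = model.to('cuda' if torch.cuda.is_available() else 'cpu').eval()

with torch.no_grad():
    outputs = model(
        xtype=['image'],                          # modalities to generate
        condition=['a dog playing in the snow'],  # one entry per condition
        condition_types=['text'],
        n_samples=1,
        image_size=256,
        ddim_steps=50,
        scale=7.5,                                # classifier-free guidance scale
    )
images = outputs[0]                               # [1, 3, 256, 256] tensor in [0, 1]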
core/models/ddim/__pycache__/ddim.cpython-38.pyc
ADDED
Binary file (6.27 kB).

core/models/ddim/__pycache__/ddim_vd.cpython-38.pyc
ADDED
Binary file (4.29 kB).

core/models/ddim/__pycache__/diffusion_utils.cpython-38.pyc
ADDED
Binary file (9.56 kB).
core/models/ddim/ddim.py
ADDED
@@ -0,0 +1,224 @@
"""SAMPLING ONLY."""

import torch
import numpy as np
from tqdm import tqdm
from functools import partial

from .diffusion_utils import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like


class DDIMSampler(object):
    def __init__(self, model, schedule="linear", **kwargs):
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

    def register_buffer(self, name, attr):
        device = self.model.device

        if type(attr) == torch.Tensor:
            if attr.device != device:
                attr = attr.to(device)
        setattr(self, name, attr)

    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize,
                                                  num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps,
                                                  verbose=verbose)
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
            alphacums=alphas_cumprod.cpu(),
            ddim_timesteps=self.ddim_timesteps,
            eta=ddim_eta, verbose=verbose)

        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)

    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               video_frame_share_noise=False,
               # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
               **kwargs
               ):
        device = self.model.device

        if conditioning is not None:
            if isinstance(conditioning, dict):
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        print(f'Data shape for DDIM sampling is {size}, eta {eta}')

        samples, intermediates = self.ddim_sampling(conditioning, size,
                                                    callback=callback,
                                                    img_callback=img_callback,
                                                    quantize_denoised=quantize_x0,
                                                    mask=mask, x0=x0,
                                                    ddim_use_original_steps=False,
                                                    noise_dropout=noise_dropout,
                                                    temperature=temperature,
                                                    score_corrector=score_corrector,
                                                    corrector_kwargs=corrector_kwargs,
                                                    x_T=x_T,
                                                    log_every_t=log_every_t,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    unconditional_conditioning=unconditional_conditioning,
                                                    )
        return samples, intermediates

    @torch.no_grad()
    def ddim_sampling(self,
                      cond, shape,
                      x_T=None,
                      ddim_use_original_steps=False,
                      callback=None,
                      timesteps=None,
                      quantize_denoised=False,
                      mask=None, x0=None,
                      img_callback=None, log_every_t=100,
                      temperature=1.,
                      noise_dropout=0.,
                      score_corrector=None,
                      corrector_kwargs=None,
                      unconditional_guidance_scale=1.,
                      unconditional_conditioning=None,):
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif timesteps is not None and not ddim_use_original_steps:
            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {'x_inter': [img], 'pred_x0': [img]}
        time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)

            if mask is not None:
                assert x0 is not None
                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img

            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning)
            img, pred_x0 = outs
            if callback:
                callback(i)
            if img_callback:
                img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)

        return img, intermediates

    @torch.no_grad()
    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None):
        b, *_, device = *x.shape, x.device

        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            e_t = self.model.apply_model(x, t, c)
        else:
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)
            c_in = torch.cat([unconditional_conditioning, c])
            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
            e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

        if score_corrector is not None:
            assert self.model.parameterization == "eps"
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

        # current prediction for x_0
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
        # direction pointing to x_t
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * noise_like(x, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0
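The guidance branch of p_sample_ddim above batches the unconditional and conditional inputs into a single forward pass and then recombines the two noise estimates. A self-contained sketch of just that arithmetic, with a dummy stand-in for model.apply_model, looks like this:

# Standalone sketch of the classifier-free guidance step used in p_sample_ddim.
# fake_eps is a dummy stand-in for self.model.apply_model; shapes follow a 32x32 latent.
import torch

def fake_eps(x, t, c):
    # pretend noise prediction; a real model would use t and cross-attend to c
    return 0.1 * x + 0.01 * c.mean()

b, scale = 2, 7.5
x = torch.randn(b, 4, 32, 32)              # current latent x_t
t = torch.full((b,), 500, dtype=torch.long)
cond = torch.randn(b, 77, 768)             # conditional embedding
uncond = torch.zeros_like(cond)            # unconditional (empty prompt) embedding

x_in = torch.cat([x] * 2)                  # one batch, two halves
t_in = torch.cat([t] * 2)
c_in = torch.cat([uncond, cond])
e_t_uncond, e_t = fake_eps(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + scale * (e_t - e_t_uncond)   # guided noise estimate, shape [2, 4, 32, 32]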
core/models/ddim/ddim_vd.py
ADDED
@@ -0,0 +1,175 @@
"""
https://github.com/SHI-Labs/Versatile-Diffusion
"""

import torch
import numpy as np
from tqdm import tqdm
from functools import partial

from .diffusion_utils import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like

from .ddim import DDIMSampler


class DDIMSampler_VD(DDIMSampler):
    @torch.no_grad()
    def sample(self,
               steps,
               shape,
               xt=None,
               condition=None,
               unconditional_guidance_scale=1.,
               xtype='image',
               condition_types=['text'],
               eta=0.,
               temperature=1.,
               mix_weight=None,
               noise_dropout=0.,
               verbose=True,
               log_every_t=100, ):

        self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose)
        print(f'Data shape for DDIM sampling is {shape}, eta {eta}')
        samples, intermediates = self.ddim_sampling(
            shape,
            xt=xt,
            condition=condition,
            unconditional_guidance_scale=unconditional_guidance_scale,
            xtype=xtype,
            condition_types=condition_types,
            ddim_use_original_steps=False,
            noise_dropout=noise_dropout,
            temperature=temperature,
            log_every_t=log_every_t,
            mix_weight=mix_weight, )
        return samples, intermediates

    @torch.no_grad()
    def ddim_sampling(self,
                      shape,
                      xt=None,
                      condition=None,
                      unconditional_guidance_scale=1.,
                      xtype=['image'],
                      condition_types=['text'],
                      ddim_use_original_steps=False,
                      timesteps=None,
                      noise_dropout=0.,
                      temperature=1.,
                      mix_weight=None,
                      log_every_t=100, ):

        device = self.model.device
        dtype = condition[0][0].dtype

        if isinstance(shape[0], list):
            bs = shape[0][0]
        else:
            bs = shape[0]
        if xt is None:
            if isinstance(shape[0], list):
                xt = [torch.randn(shape_i, device=device, dtype=dtype) for shape_i in shape]
            else:
                xt = torch.randn(shape, device=device, dtype=dtype)

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif timesteps is not None and not ddim_use_original_steps:
            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {'pred_xt': [], 'pred_x0': []}
        time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        # print(f"Running DDIM Sampling with {total_steps} timesteps")

        pred_xt = xt
        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((bs,), step, device=device, dtype=torch.long)

            outs = self.p_sample_ddim(
                pred_xt,
                condition,
                ts, index,
                unconditional_guidance_scale=unconditional_guidance_scale,
                xtype=xtype,
                condition_types=condition_types,
                use_original_steps=ddim_use_original_steps,
                noise_dropout=noise_dropout,
                temperature=temperature,
                mix_weight=mix_weight, )
            pred_xt, pred_x0 = outs

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['pred_xt'].append(pred_xt)
                intermediates['pred_x0'].append(pred_x0)

        return pred_xt, intermediates

    @torch.no_grad()
    def p_sample_ddim(self, x,
                      condition,
                      t, index,
                      unconditional_guidance_scale=1.,
                      xtype=['image'],
                      condition_types=['text'],
                      repeat_noise=False,
                      use_original_steps=False,
                      noise_dropout=0.,
                      temperature=1.,
                      mix_weight=None, ):

        b, *_, device = *x[0].shape, x[0].device

        x_in = []
        for x_i in x:
            x_in.append(torch.cat([x_i] * 2))
        t_in = torch.cat([t] * 2)

        out = self.model.model.diffusion_model(
            x_in, t_in, condition, xtype=xtype, condition_types=condition_types, mix_weight=mix_weight)
        e_t = []
        for out_i in out:
            e_t_uncond_i, e_t_i = out_i.chunk(2)
            e_t_i = e_t_uncond_i + unconditional_guidance_scale * (e_t_i - e_t_uncond_i)
            e_t_i = e_t_i.to(device)
            e_t.append(e_t_i)

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
        # select parameters corresponding to the currently considered timestep

        x_prev = []
        pred_x0 = []
        device = x[0].device
        dtype = x[0].dtype
        for i, xtype_i in enumerate(xtype):
            if xtype_i in ['image', 'frontal', 'lateral']:
                extended_shape = (b, 1, 1, 1)
            elif xtype_i == 'video':
                extended_shape = (b, 1, 1, 1, 1)
            elif xtype_i == 'text':
                extended_shape = (b, 1)
            elif xtype_i == 'audio':
                extended_shape = (b, 1, 1, 1)

            a_t = torch.full(extended_shape, alphas[index], device=device, dtype=dtype)
            a_prev = torch.full(extended_shape, alphas_prev[index], device=device, dtype=dtype)
            sigma_t = torch.full(extended_shape, sigmas[index], device=device, dtype=dtype)
            sqrt_one_minus_at = torch.full(extended_shape, sqrt_one_minus_alphas[index], device=device, dtype=dtype)

            # current prediction for x_0
            pred_x0_i = (x[i] - sqrt_one_minus_at * e_t[i]) / a_t.sqrt()
            dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t[i]
            noise = sigma_t * noise_like(x[i], repeat_noise) * temperature
            if noise_dropout > 0.:
                noise = torch.nn.functional.dropout(noise, p=noise_dropout)
            x_prev_i = a_prev.sqrt() * pred_x0_i + dir_xt + noise
            x_prev.append(x_prev_i)
            pred_x0.append(pred_x0_i)
        return x_prev, pred_x0
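DDIMSampler_VD generalises the sampler to lists: one latent shape per requested output modality and, per input modality, one conditioning tensor whose first half is the unconditional embedding. A small sketch of that calling convention follows; the embeddings are random placeholders and a real call needs the full network, so the sample call itself is left commented.

# Sketch of the inputs DDIMSampler_VD.sample expects for a joint image + text request.
# Embeddings are random placeholders; shapes mirror those built in dani_model.forward.
import torch

n_samples = 2
condition = [torch.cat([torch.zeros(n_samples, 77, 768),    # unconditional half
                        torch.randn(n_samples, 77, 768)])]  # conditional half (e.g. CLIP text)
shapes = [[n_samples, 4, 32, 32],    # 256x256 image latent (image_size // 8)
          [n_samples, 768]]          # text latent

# With a constructed net this would be:
# sampler = DDIMSampler_VD(net)
# z, _ = sampler.sample(steps=50, shape=shapes, condition=condition,
#                       unconditional_guidance_scale=7.5,
#                       xtype=['image', 'text'], condition_types=['text'],
#                       eta=0.0, verbose=False)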
core/models/ddim/diffusion_utils.py
ADDED
@@ -0,0 +1,273 @@
import os
import math
import torch
import torch.nn as nn
import numpy as np
from einops import repeat


def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
    if schedule == "linear":
        betas = (
            torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
        )

    elif schedule == "cosine":
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        betas = np.clip(betas, a_min=0, a_max=0.999)

    elif schedule == "sqrt_linear":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
    elif schedule == "sqrt":
        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()


def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
    if ddim_discr_method == 'uniform':
        c = num_ddpm_timesteps // num_ddim_timesteps
        ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
    elif ddim_discr_method == 'quad':
        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
    else:
        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')

    # assert ddim_timesteps.shape[0] == num_ddim_timesteps
    # add one to get the final alpha values right (the ones from first scale to data during sampling)
    if num_ddpm_timesteps != 1000:
        steps_out = ddim_timesteps + 1
    else:
        steps_out = ddim_timesteps
    if verbose:
        print(f'Selected timesteps for ddim sampler: {steps_out}')
    return steps_out


def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
    # select alphas for computing the variance schedule
    alphas = alphacums[ddim_timesteps]
    alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())

    # according to the formula provided in https://arxiv.org/abs/2010.02502
    sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
    if verbose:
        print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
        print(f'For the chosen value of eta, which is {eta}, '
              f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
    return sigmas, alphas, alphas_prev


def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].
    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return np.array(betas)


def extract_into_tensor(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.
    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if flag:
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)


class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])

        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + input_grads


def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=timesteps.device)
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    else:
        embedding = repeat(timesteps, 'b -> b d', d=dim)
    return embedding


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def scale_module(module, scale):
    """
    Scale the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().mul_(scale)
    return module


def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))


def normalization(channels):
    """
    Make a standard normalization layer.
    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(32, channels)


# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


class GroupNorm32(nn.GroupNorm):
    def forward(self, x):
        # return super().forward(x.float()).type(x.dtype)
        return super().forward(x)


def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def linear(*args, **kwargs):
    """
    Create a linear module.
    """
    return nn.Linear(*args, **kwargs)


def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    """
    if dims == 1:
        return nn.AvgPool1d(*args, **kwargs)
    elif dims == 2:
        return nn.AvgPool2d(*args, **kwargs)
    elif dims == 3:
        return nn.AvgPool3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


class HybridConditioner(nn.Module):

    def __init__(self, c_concat_config, c_crossattn_config):
        super().__init__()
        self.concat_conditioner = instantiate_from_config(c_concat_config)
        self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)

    def forward(self, c_concat, c_crossattn):
        c_concat = self.concat_conditioner(c_concat)
        c_crossattn = self.crossattn_conditioner(c_crossattn)
        return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}


def noise_like(x, repeat=False):
    noise = torch.randn_like(x)
    if repeat:
        bs = x.shape[0]
        noise = noise[0:1].repeat(bs, *((1,) * (len(x.shape) - 1)))
    return noise

##########################
# inherit from ldm.utils #
##########################


def count_params(model, verbose=False):
    total_params = sum(p.numel() for p in model.parameters())
    if verbose:
        print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
    return total_params
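A quick, self-contained check of the schedule helpers above (a sketch only; it needs just numpy and torch, and the import path assumes the repository layout of this upload):

# Sketch: exercising the schedule helpers with a 1000-step DDPM schedule
# subsampled to 50 DDIM steps.
import numpy as np
import torch
from core.models.ddim.diffusion_utils import (
    make_beta_schedule, make_ddim_timesteps, make_ddim_sampling_parameters, timestep_embedding)

betas = make_beta_schedule("linear", n_timestep=1000)
alphas_cumprod = np.cumprod(1.0 - betas)

ddim_ts = make_ddim_timesteps("uniform", num_ddim_timesteps=50,
                              num_ddpm_timesteps=1000, verbose=False)
sigmas, alphas, alphas_prev = make_ddim_sampling_parameters(
    alphacums=alphas_cumprod, ddim_timesteps=ddim_ts, eta=0.0, verbose=False)

emb = timestep_embedding(torch.tensor([0, 250, 999]), dim=320)
print(ddim_ts.shape, sigmas.shape, emb.shape)   # (50,) (50,) torch.Size([3, 320])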
core/models/ema.py
ADDED
@@ -0,0 +1,76 @@
import torch
from torch import nn


class LitEma(nn.Module):
    def __init__(self, model, decay=0.9999, use_num_updates=True):
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')

        self.m_name2s_name = {}
        self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
        self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_updates
                             else torch.tensor(-1, dtype=torch.int))

        for name, p in model.named_parameters():
            if p.requires_grad:
                # remove as '.'-character is not allowed in buffers
                s_name = name.replace('.', '')
                self.m_name2s_name.update({name: s_name})
                self.register_buffer(s_name, p.clone().detach().data)

        self.collected_params = []

    def forward(self, model):
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))

        one_minus_decay = 1.0 - decay

        with torch.no_grad():
            m_param = dict(model.named_parameters())
            shadow_params = dict(self.named_buffers())

            for key in m_param:
                if m_param[key].requires_grad:
                    sname = self.m_name2s_name[key]
                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
                    shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
                else:
                    assert key not in self.m_name2s_name

    def copy_to(self, model):
        m_param = dict(model.named_parameters())
        shadow_params = dict(self.named_buffers())
        for key in m_param:
            if m_param[key].requires_grad:
                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
            else:
                assert key not in self.m_name2s_name

    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)
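A minimal sketch of how the EMA helper above is typically driven during training and evaluation (toy module, illustrative only):

# Sketch: maintaining and evaluating EMA weights with LitEma on a toy model.
import torch
from torch import nn
from core.models.ema import LitEma

model = nn.Linear(4, 4)
ema = LitEma(model, decay=0.999)

opt = torch.optim.SGD(model.parameters(), lr=1e-2)
for _ in range(10):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    ema(model)                        # update shadow weights after each optimizer step

ema.store(model.parameters())         # stash current training weights
ema.copy_to(model)                    # evaluate with the EMA weights
ema.restore(model.parameters())       # put the training weights back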
core/models/encoders/__pycache__/clap.cpython-311.pyc
ADDED
Binary file (7.09 kB).

core/models/encoders/__pycache__/clap.cpython-38.pyc
ADDED
Binary file (4.16 kB).

core/models/encoders/__pycache__/clip.cpython-311.pyc
ADDED
Binary file (10.4 kB).

core/models/encoders/__pycache__/clip.cpython-38.pyc
ADDED
Binary file (6 kB).
core/models/encoders/clap.py
ADDED
@@ -0,0 +1,134 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio

from .clap_modules.open_clip import create_model
from .clap_modules.training.data import get_audio_features

from ..common.get_model import register


@register('clap_audio')
class CLAPAudioEmbeddingClassifierFreev2(nn.Module):
    """Uses the CLAP audio encoder"""
    def __init__(
        self,
        pretrained_path="",
        key="waveform",
        sampling_rate=16000,
        embed_mode="audio",
        unconditional_prob=0.1,
        random_mute=False,
        max_random_mute_portion=0.5,
        training_mode=True,
        joint_embed_shape=768,
        embed_shape=512,
        num_layers=12,
        depths=[2, 2, 6, 2],
        amodel="HTSAT-large",
    ):
        super().__init__()

        self.key = key
        self.amodel = amodel  # or 'PANN-14'
        self.tmodel = "roberta"  # the best text encoder in our training
        self.enable_fusion = False  # False if you do not want to use the fusion model
        self.fusion_type = "aff_2d"
        self.pretrained = pretrained_path
        self.embed_mode = embed_mode
        self.embed_mode_orig = embed_mode
        self.sampling_rate = sampling_rate
        self.unconditional_prob = unconditional_prob
        self.random_mute = random_mute
        self.joint_embed_shape = joint_embed_shape
        self.max_random_mute_portion = max_random_mute_portion
        self.training_mode = training_mode
        self.model, self.model_cfg = create_model(
            self.amodel,
            self.tmodel,
            self.pretrained,
            precision="fp32",
            device="cpu",
            enable_fusion=self.enable_fusion,
            fusion_type=self.fusion_type,
            joint_embed_shape=self.joint_embed_shape,
        )

    def get_dtype(self):
        return next(self.model.parameters()).dtype

    def get_unconditional_condition(self, batchsize):
        # NOTE: relies on self.tokenizer being attached externally before this is called.
        self.unconditional_token = self.model.get_text_embedding(
            self.tokenizer(["", ""])
        )[0:1]
        return torch.cat([self.unconditional_token.unsqueeze(0)] * batchsize, dim=0)

    def batch_to_list(self, batch):
        ret = []
        for i in range(batch.size(0)):
            ret.append(batch[i])
        return ret

    def make_decision(self, probability):
        if float(torch.rand(1)) < probability:
            return True
        else:
            return False

    def random_uniform(self, start, end):
        val = torch.rand(1).item()
        return start + (end - start) * val

    def _random_mute(self, waveform):
        # waveform: [bs, t-steps]
        t_steps = waveform.size(-1)
        for i in range(waveform.size(0)):
            mute_size = int(
                self.random_uniform(0, end=int(t_steps * self.max_random_mute_portion))
            )
            mute_start = int(self.random_uniform(0, t_steps - mute_size))
            waveform[i, mute_start : mute_start + mute_size] = 0
        return waveform

    def cos_similarity(self, waveform, text):
        # waveform: [bs, t_steps]
        with torch.no_grad():
            self.embed_mode = "audio"
            audio_emb = self(waveform.cuda())
            self.embed_mode = "text"
            text_emb = self(text)
            similarity = F.cosine_similarity(audio_emb, text_emb, dim=2)
            return similarity.squeeze()

    def forward(self, batch, key=None):

        # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode
        if self.embed_mode == "audio":
            audio_dict_list = []
            assert (
                self.sampling_rate == 16000
            ), "We only support 16000 sampling rate"
            # batch: [bs, 1, t-samples]
            batch = torchaudio.functional.resample(
                batch, orig_freq=self.sampling_rate, new_freq=48000
            )

            for waveform in self.batch_to_list(batch):
                audio_dict = {}
                audio_dict = get_audio_features(
                    audio_dict,
                    waveform.squeeze(),
                    480000,
                    data_truncating="fusion",
                    data_filling="repeatpad",
                    audio_cfg=self.model_cfg["audio_cfg"],
                    dtype=self.get_dtype(),
                )
                audio_dict_list.append(audio_dict)
            # [bs, 768]
            embed = self.model.get_audio_embedding(audio_dict_list)

        embed = embed.unsqueeze(1)

        # [bs, 1, 768]
        return embed
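For reference, the encoder above expects mono 16 kHz waveforms shaped [batch, 1, samples] and returns one 768-dimensional embedding per clip. A hedged sketch of the intended call; the constructor line is left commented because it pulls in the bundled clap_modules:

# Sketch of the I/O contract of CLAPAudioEmbeddingClassifierFreev2.
import torch
# from core.models.encoders.clap import CLAPAudioEmbeddingClassifierFreev2

waveform = torch.randn(2, 1, 16000 * 10)   # batch of two 10-second 16 kHz clips
# encoder = CLAPAudioEmbeddingClassifierFreev2(amodel="HTSAT-large", joint_embed_shape=768)
# embed = encoder(waveform)                # expected shape: [2, 1, 768]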