dmolino committed on
Commit 9a7fe1f · verified · 1 Parent(s): 7c28703

Upload 225 files

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. configs/model/audioldm.yaml +24 -0
  2. configs/model/clap.yaml +10 -0
  3. configs/model/clip.yaml +22 -0
  4. configs/model/codi.yaml +23 -0
  5. configs/model/openai_unet.yaml +85 -0
  6. configs/model/optimus.yaml +107 -0
  7. configs/model/prova.yaml +85 -0
  8. configs/model/sd.yaml +20 -0
  9. configs/model/thesis_model.yaml +21 -0
  10. core/__init__.py +0 -0
  11. core/__pycache__/__init__.cpython-38.pyc +0 -0
  12. core/__pycache__/cfg_helper.cpython-38.pyc +0 -0
  13. core/__pycache__/cfg_holder.cpython-38.pyc +0 -0
  14. core/__pycache__/sync.cpython-38.pyc +0 -0
  15. core/cfg_helper.py +665 -0
  16. core/cfg_holder.py +33 -0
  17. core/common/__pycache__/utils.cpython-38.pyc +0 -0
  18. core/common/registry.py +86 -0
  19. core/common/utils.py +412 -0
  20. core/models/__init__.py +4 -0
  21. core/models/__pycache__/__init__.cpython-38.pyc +0 -0
  22. core/models/__pycache__/codi.cpython-38.pyc +0 -0
  23. core/models/__pycache__/codi_2.cpython-38.pyc +0 -0
  24. core/models/__pycache__/dani_model.cpython-38.pyc +0 -0
  25. core/models/__pycache__/ema.cpython-38.pyc +0 -0
  26. core/models/__pycache__/model_module_infer.cpython-38.pyc +0 -0
  27. core/models/__pycache__/sd.cpython-38.pyc +0 -0
  28. core/models/codi.py +227 -0
  29. core/models/codi_2.py +221 -0
  30. core/models/common/__pycache__/get_model.cpython-38.pyc +0 -0
  31. core/models/common/__pycache__/get_optimizer.cpython-38.pyc +0 -0
  32. core/models/common/__pycache__/get_scheduler.cpython-38.pyc +0 -0
  33. core/models/common/__pycache__/utils.cpython-38.pyc +0 -0
  34. core/models/common/get_model.py +88 -0
  35. core/models/common/get_optimizer.py +50 -0
  36. core/models/common/get_scheduler.py +273 -0
  37. core/models/common/utils.py +310 -0
  38. core/models/dani_model.py +170 -0
  39. core/models/ddim/__pycache__/ddim.cpython-38.pyc +0 -0
  40. core/models/ddim/__pycache__/ddim_vd.cpython-38.pyc +0 -0
  41. core/models/ddim/__pycache__/diffusion_utils.cpython-38.pyc +0 -0
  42. core/models/ddim/ddim.py +224 -0
  43. core/models/ddim/ddim_vd.py +175 -0
  44. core/models/ddim/diffusion_utils.py +273 -0
  45. core/models/ema.py +76 -0
  46. core/models/encoders/__pycache__/clap.cpython-311.pyc +0 -0
  47. core/models/encoders/__pycache__/clap.cpython-38.pyc +0 -0
  48. core/models/encoders/__pycache__/clip.cpython-311.pyc +0 -0
  49. core/models/encoders/__pycache__/clip.cpython-38.pyc +0 -0
  50. core/models/encoders/clap.py +134 -0
configs/model/audioldm.yaml ADDED
@@ -0,0 +1,24 @@
+ ########################
+ # audioldm autoencoder #
+ ########################
+
+
+ audioldm_autoencoder:
+   type: audioldm_autoencoder
+   args:
+     embed_dim: 8
+     monitor: val/rec_loss
+     ddconfig:
+       double_z: True
+       z_channels: 8
+       resolution: 256
+       downsample_time: False
+       in_channels: 1
+       out_ch: 1
+       ch: 128
+       ch_mult: [1, 2, 4]
+       num_res_blocks: 2
+       attn_resolutions: []
+       dropout: 0.0
+     lossconfig:
+       target: torch.nn.Identity
configs/model/clap.yaml ADDED
@@ -0,0 +1,10 @@
+ ######################
+ # clap audio encoder #
+ ######################
+
+
+ clap_audio:
+   type: clap_audio
+   args:
+     amodel: "HTSAT-large"
+     joint_embed_shape: 768
configs/model/clip.yaml ADDED
@@ -0,0 +1,22 @@
+ ##############################
+ # clip vision & text encoder #
+ ##############################
+
+ clip:
+   symbol: clip
+   args: {}
+
+ clip_frozen:
+   super_cfg: clip
+   type: clip_frozen
+   args: {}
+
+ clip_text:
+   super_cfg: clip
+   type: clip_text
+   args: {}
+
+ clip_vision:
+   super_cfg: clip
+   type: clip_vision
+   args: {}
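
For reference, the `super_cfg: clip` entries above are expanded by `model_cfg_bank` in core/cfg_helper.py (included later in this commit): the child entry overrides the parent's fields, except `args`, which is merged. A minimal sketch of that resolution, assuming the repository root is the working directory and the package's dependencies are installed (not part of the commit):

    # Sketch only: resolve the 'clip_frozen' entry defined above.
    from core.cfg_helper import model_cfg_bank

    cfg = model_cfg_bank()('clip_frozen')
    print(cfg.type)    # 'clip_frozen'  (from the child entry)
    print(cfg.symbol)  # 'clip'         (inherited from the 'clip' super_cfg)
    print(cfg.args)    # {}             (child args merged over parent args)
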
configs/model/codi.yaml ADDED
@@ -0,0 +1,23 @@
+ ########
+ # CoDi #
+ ########
+
+ codi:
+   type: codi
+   symbol: codi
+   find_unused_parameters: true
+   args:
+     audioldm_cfg: MODEL(audioldm_autoencoder)
+     autokl_cfg: MODEL(sd_autoencoder)
+     optimus_cfg: MODEL(optimus_vae)
+     clip_cfg: MODEL(clip_frozen)
+     clap_cfg: MODEL(clap_audio)
+     unet_config: MODEL(openai_unet_codi)
+     beta_linear_start: 0.00085
+     beta_linear_end: 0.012
+     timesteps: 1000
+     vision_scale_factor: 0.18215
+     text_scale_factor: 4.3108
+     audio_scale_factor: 0.9228
+     use_ema: false
+     parameterization: "eps"
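
The `MODEL(...)` values above are placeholders, not literal strings: `cfg_solvef` in core/cfg_helper.py (later in this commit) replaces each one with the fully resolved config of the named model, recursively. A rough sketch under the same assumptions as the previous note (not part of the commit):

    # Sketch only: 'MODEL(clip_frozen)' etc. are expanded on load.
    from core.cfg_helper import model_cfg_bank

    codi_cfg = model_cfg_bank()('codi')
    print(codi_cfg.args.clip_cfg.type)     # 'clip_frozen'
    print(codi_cfg.args.unet_config.type)  # 'openai_unet_codi'
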
configs/model/openai_unet.yaml ADDED
@@ -0,0 +1,85 @@
+ openai_unet_sd:
+   type: openai_unet
+   args:
+     image_size: null # no use
+     in_channels: 4
+     out_channels: 4
+     model_channels: 320
+     attention_resolutions: [ 4, 2, 1 ]
+     num_res_blocks: [ 2, 2, 2, 2 ]
+     channel_mult: [ 1, 2, 4, 4 ]
+     num_heads: 8
+     use_spatial_transformer: True
+     transformer_depth: 1
+     context_dim: 768
+     use_checkpoint: True
+     legacy: False
+
+ openai_unet_dual_context:
+   super_cfg: openai_unet_sd
+   type: openai_unet_dual_context
+
+ ########################
+ # Code cleaned version #
+ ########################
+
+ openai_unet_2d_audio:
+   type: openai_unet_2d
+   args:
+     input_channels: 8
+     model_channels: 192
+     output_channels: 8
+     num_noattn_blocks: [ 2, 2, 2, 2 ]
+     channel_mult: [ 1, 2, 4, 4 ]
+     with_attn: [true, true, true, false]
+     channel_mult_connector: [1, 2, 4]
+     num_noattn_blocks_connector: [1, 1, 1]
+     with_connector: [True, True, True, False]
+     connector_output_channel: 1280
+     num_heads: 8
+     context_dim: 768
+     use_checkpoint: False
+
+ openai_unet_2d:
+   type: openai_unet_2d
+   args:
+     input_channels: 4
+     model_channels: 320
+     output_channels: 4
+     num_noattn_blocks: [ 2, 2, 2, 2 ]
+     channel_mult: [ 1, 2, 4, 4 ]
+     with_attn: [true, true, true, false]
+     channel_mult_connector: [1, 2, 4]
+     num_noattn_blocks_connector: [1, 1, 1]
+     with_connector: [True, True, True, False]
+     connector_output_channel: 1280
+     num_heads: 8
+     context_dim: 768
+     use_checkpoint: True
+     use_video_architecture: True
+
+ openai_unet_0dmd:
+   type: openai_unet_0dmd
+   args:
+     input_channels: 768
+     model_channels: 320
+     output_channels: 768
+     num_noattn_blocks: [ 2, 2, 2, 2 ]
+     channel_mult: [ 1, 2, 4, 4 ]
+     second_dim: [ 4, 4, 4, 4 ]
+     with_attn: [true, true, true, false]
+     num_noattn_blocks_connector: [1, 1, 1]
+     second_dim_connector: [4, 4, 4]
+     with_connector: [True, True, True, False]
+     connector_output_channel: 1280
+     num_heads: 8
+     context_dim: 768
+     use_checkpoint: True
+
+ openai_unet_codi:
+   type: openai_unet_codi
+   args:
+     unet_image_cfg: MODEL(openai_unet_2d)
+     unet_text_cfg: MODEL(openai_unet_0dmd)
+     unet_audio_cfg: MODEL(openai_unet_2d_audio)
+     model_type: ['video', 'image', 'text']
configs/model/optimus.yaml ADDED
@@ -0,0 +1,107 @@
+
+ optimus:
+   symbol: optimus
+   find_unused_parameters: false
+   args: {}
+
+ optimus_bert_encoder:
+   super_cfg: optimus
+   type: optimus_bert_connector
+   # pth: pretrained/optimus_bert_encoder.pth
+   args:
+     config:
+       architectures:
+         - BertForMaskedLM
+       attention_probs_dropout_prob: 0.1
+       finetuning_task: null
+       hidden_act: gelu
+       hidden_dropout_prob: 0.1
+       hidden_size: 768
+       initializer_range: 0.02
+       intermediate_size: 3072
+       layer_norm_eps: 1.e-12
+       max_position_embeddings: 512
+       num_attention_heads: 12
+       num_hidden_layers: 12
+       num_labels: 2
+       output_attentions: false
+       output_hidden_states: false
+       pruned_heads: {}
+       torchscript: false
+       type_vocab_size: 2
+       vocab_size: 28996
+     latent_size: 768
+
+ optimus_bert_tokenizer:
+   super_cfg: optimus
+   type: optimus_bert_tokenizer
+   args:
+     do_lower_case: false
+     max_len: 512
+     vocab_file: core/models/latent_diffusion/vae/optimus_modules/vocab/bert-base-cased-vocab.txt
+
+ optimus_gpt2_decoder:
+   super_cfg: optimus
+   type: optimus_gpt2_connector
+   # pth: pretrained/optimus_gpt2_decoder.pth
+   args:
+     config:
+       architectures:
+         - GPT2LMHeadModel
+       attn_pdrop: 0.1
+       embd_pdrop: 0.1
+       finetuning_task: null
+       hidden_size: 768
+       initializer_range: 0.02
+       latent_size: 768
+       layer_norm_epsilon: 1.e-05
+       max_position_embeddings: 1024
+       n_ctx: 1024
+       n_embd: 768
+       n_head: 12
+       n_layer: 12
+       n_positions: 1024
+       num_attention_heads: 12
+       num_hidden_layers: 12
+       num_labels: 1
+       output_attentions: false
+       output_hidden_states: false
+       pretrained_config_archive_map:
+         gpt2: https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json
+         gpt2-medium: https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json
+         gpt2-large: https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-config.json
+       pruned_heads: {}
+       resid_pdrop: 0.1
+       summary_activation: null
+       summary_first_dropout: 0.1
+       summary_proj_to_labels: true
+       summary_type: cls_index
+       summary_use_proj: true
+       torchscript: false
+       vocab_size: 50260
+
+ optimus_gpt2_tokenizer:
+   super_cfg: optimus
+   type: optimus_gpt2_tokenizer
+   args:
+     do_lower_case: false
+     max_len: 1024
+     vocab_file: core/models/latent_diffusion/vae/optimus_modules/vocab/gpt2-vocab.json
+     merges_file: core/models/latent_diffusion/vae/optimus_modules/vocab/gpt2-merges.txt
+
+ optimus_vae:
+   super_cfg: optimus
+   type: optimus_vae
+   pth: pretrained/optimus-vae.pth
+   args:
+     encoder: MODEL(optimus_bert_encoder)
+     decoder: MODEL(optimus_gpt2_decoder)
+     tokenizer_encoder: MODEL(optimus_bert_tokenizer)
+     tokenizer_decoder: MODEL(optimus_gpt2_tokenizer)
+     args:
+       latent_size: 768
+       beta: 1.0
+       fb_mode: 0
+       length_weighted_loss: false
+       dim_target_kl: 3.0
+
configs/model/prova.yaml ADDED
@@ -0,0 +1,85 @@
+ openai_unet_sd:
+   type: openai_unet
+   args:
+     image_size: null # no use
+     in_channels: 4
+     out_channels: 4
+     model_channels: 320
+     attention_resolutions: [ 4, 2, 1 ]
+     num_res_blocks: [ 2, 2, 2, 2 ]
+     channel_mult: [ 1, 2, 4, 4 ]
+     num_heads: 8
+     use_spatial_transformer: True
+     transformer_depth: 1
+     context_dim: 768
+     use_checkpoint: True
+     legacy: False
+
+ openai_unet_dual_context:
+   super_cfg: openai_unet_sd
+   type: openai_unet_dual_context
+
+ ########################
+ # Code cleaned version #
+ ########################
+
+ openai_unet_2d_audio:
+   type: openai_unet_2d
+   args:
+     input_channels: 8
+     model_channels: 192
+     output_channels: 8
+     num_noattn_blocks: [ 2, 2, 2, 2 ]
+     channel_mult: [ 1, 2, 4, 4 ]
+     with_attn: [true, true, true, false]
+     channel_mult_connector: [1, 2, 4]
+     num_noattn_blocks_connector: [1, 1, 1]
+     with_connector: [True, True, True, False]
+     connector_output_channel: 1280
+     num_heads: 8
+     context_dim: 768
+     use_checkpoint: False
+
+ openai_unet_2d:
+   type: openai_unet_2d
+   args:
+     input_channels: 4
+     model_channels: 320
+     output_channels: 4
+     num_noattn_blocks: [ 2, 2, 2, 2 ]
+     channel_mult: [ 1, 2, 4, 4 ]
+     with_attn: [true, true, true, false]
+     channel_mult_connector: [1, 2, 4]
+     num_noattn_blocks_connector: [1, 1, 1]
+     with_connector: [True, True, True, False]
+     connector_output_channel: 1280
+     num_heads: 8
+     context_dim: 768
+     use_checkpoint: True
+     use_video_architecture: True
+
+ openai_unet_0dmd:
+   type: openai_unet_0dmd
+   args:
+     input_channels: 768
+     model_channels: 320
+     output_channels: 768
+     num_noattn_blocks: [ 2, 2, 2, 2 ]
+     channel_mult: [ 1, 2, 4, 4 ]
+     second_dim: [ 4, 4, 4, 4 ]
+     with_attn: [true, true, true, false]
+     num_noattn_blocks_connector: [1, 1, 1]
+     second_dim_connector: [4, 4, 4]
+     with_connector: [True, True, True, False]
+     connector_output_channel: 1280
+     num_heads: 8
+     context_dim: 768
+     use_checkpoint: True
+
+ prova:
+   type: prova
+   args:
+     unet_frontal_cfg: MODEL(openai_unet_2d)
+     unet_lateral_cfg: MODEL(openai_unet_2d)
+     unet_text_cfg: MODEL(openai_unet_0dmd)
+     model_type: ['text']
configs/model/sd.yaml ADDED
@@ -0,0 +1,20 @@
+ sd_autoencoder:
+   type: autoencoderkl
+   args:
+     embed_dim: 4
+     monitor: val/rec_loss
+     ddconfig:
+       double_z: true
+       z_channels: 4
+       resolution: 256
+       in_channels: 3
+       out_ch: 3
+       ch: 128
+       ch_mult: [1, 2, 4, 4]
+       num_res_blocks: 2
+       attn_resolutions: []
+       dropout: 0.0
+       use_video_arch: true
+     lossconfig:
+       target: torch.nn.Identity
+   pth: pretrained/kl-f8.pth
configs/model/thesis_model.yaml ADDED
@@ -0,0 +1,21 @@
+ ########
+ # CoDi #
+ ########
+
+ thesis_model:
+   type: thesis_model
+   symbol: thesis_model
+   find_unused_parameters: true
+   args:
+     autokl_cfg: MODEL(sd_autoencoder)
+     optimus_cfg: MODEL(optimus_vae)
+     clip_cfg: MODEL(clip_frozen)
+     unet_config: MODEL(prova)
+     beta_linear_start: 0.00085
+     beta_linear_end: 0.012
+     timesteps: 1000
+     vision_scale_factor: 0.18215
+     text_scale_factor: 4.3108
+     audio_scale_factor: 0.9228
+     use_ema: false
+     parameterization: "eps"
core/__init__.py ADDED
File without changes
core/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (149 Bytes).
core/__pycache__/cfg_helper.cpython-38.pyc ADDED
Binary file (13 kB).
core/__pycache__/cfg_holder.cpython-38.pyc ADDED
Binary file (1.21 kB).
core/__pycache__/sync.cpython-38.pyc ADDED
Binary file (6.24 kB).
core/cfg_helper.py ADDED
@@ -0,0 +1,665 @@
+ import os
+ import os.path as osp
+ import shutil
+ import copy
+ import time
+ import pprint
+ import numpy as np
+ import torch
+ import argparse
+ import json
+ import yaml
+ from easydict import EasyDict as edict
+
+ from core.models import get_model
+
+ ############
+ # cfg_bank #
+ ############
+
+
+ def cfg_solvef(cmd, root):
+     if not isinstance(cmd, str):
+         return cmd
+
+     if cmd.find('SAME')==0:
+         zoom = root
+         p = cmd[len('SAME'):].strip('()').split('.')
+         p = [pi.strip() for pi in p]
+         for pi in p:
+             try:
+                 pi = int(pi)
+             except:
+                 pass
+
+             try:
+                 zoom = zoom[pi]
+             except:
+                 return cmd
+         return cfg_solvef(zoom, root)
+
+     if cmd.find('SEARCH')==0:
+         zoom = root
+         p = cmd[len('SEARCH'):].strip('()').split('.')
+         p = [pi.strip() for pi in p]
+         find = True
+         # Depth first search
+         for pi in p:
+             try:
+                 pi = int(pi)
+             except:
+                 pass
+
+             try:
+                 zoom = zoom[pi]
+             except:
+                 find = False
+                 break
+
+         if find:
+             return cfg_solvef(zoom, root)
+         else:
+             if isinstance(root, dict):
+                 for ri in root:
+                     rv = cfg_solvef(cmd, root[ri])
+                     if rv != cmd:
+                         return rv
+             if isinstance(root, list):
+                 for ri in root:
+                     rv = cfg_solvef(cmd, ri)
+                     if rv != cmd:
+                         return rv
+             return cmd
+
+     if cmd.find('MODEL')==0:
+         goto = cmd[len('MODEL'):].strip('()')
+         return model_cfg_bank()(goto)
+
+     if cmd.find('DATASET')==0:
+         goto = cmd[len('DATASET'):].strip('()')
+         return dataset_cfg_bank()(goto)
+
+     return cmd
+
+
+ def cfg_solve(cfg, cfg_root):
+     # This function solves cfg elements so that
+     # all surrogate inputs are settled.
+     # (i.e. SAME(***) )
+     if isinstance(cfg, list):
+         for i in range(len(cfg)):
+             if isinstance(cfg[i], (list, dict)):
+                 cfg[i] = cfg_solve(cfg[i], cfg_root)
+             else:
+                 cfg[i] = cfg_solvef(cfg[i], cfg_root)
+     if isinstance(cfg, dict):
+         for k in cfg:
+             if isinstance(cfg[k], (list, dict)):
+                 cfg[k] = cfg_solve(cfg[k], cfg_root)
+             else:
+                 cfg[k] = cfg_solvef(cfg[k], cfg_root)
+     return cfg
+
+
+ class model_cfg_bank(object):
+     def __init__(self):
+         self.cfg_dir = osp.join('configs', 'model')
+         self.cfg_bank = edict()
+
+     def __call__(self, name):
+         if name not in self.cfg_bank:
+             cfg_path = self.get_yaml_path(name)
+             with open(cfg_path, 'r') as f:
+                 cfg_new = yaml.load(
+                     f, Loader=yaml.FullLoader)
+             cfg_new = edict(cfg_new)
+             self.cfg_bank.update(cfg_new)
+
+         cfg = self.cfg_bank[name]
+         cfg.name = name
+         if 'super_cfg' not in cfg:
+             cfg = cfg_solve(cfg, cfg)
+             self.cfg_bank[name] = cfg
+             return copy.deepcopy(cfg)
+
+         super_cfg = self.__call__(cfg.super_cfg)
+         # unlike other fields,
+         # args will not be replaced but updated.
+         if 'args' in cfg:
+             if 'args' in super_cfg:
+                 super_cfg.args.update(cfg.args)
+             else:
+                 super_cfg.args = cfg.args
+             cfg.pop('args')
+
+         super_cfg.update(cfg)
+         super_cfg.pop('super_cfg')
+         cfg = super_cfg
+         try:
+             delete_args = cfg.pop('delete_args')
+         except:
+             delete_args = []
+
+         for dargs in delete_args:
+             cfg.args.pop(dargs)
+
+         cfg = cfg_solve(cfg, cfg)
+         self.cfg_bank[name] = cfg
+         return copy.deepcopy(cfg)
+
+     def get_yaml_path(self, name):
+         if name.find('openai_unet')==0:
+             return osp.join(
+                 self.cfg_dir, 'openai_unet.yaml')
+         elif name.find('prova')==0:
+             return osp.join(
+                 self.cfg_dir, 'prova.yaml')
+         elif name.find('audioldm')==0:
+             return osp.join(
+                 self.cfg_dir, 'audioldm.yaml')
+         elif name.find('clip')==0:
+             return osp.join(
+                 self.cfg_dir, 'clip.yaml')
+         elif name.find('sd')==0:
+             return osp.join(
+                 self.cfg_dir, 'sd.yaml')
+         elif name.find('codi')==0:
+             return osp.join(
+                 self.cfg_dir, 'codi.yaml')
+         elif name.find('thesis_model')==0:
+             return osp.join(
+                 self.cfg_dir, 'thesis_model.yaml')
+         elif name.find('clap')==0:
+             return osp.join(
+                 self.cfg_dir, 'clap.yaml')
+         elif name.find('optimus')==0:
+             return osp.join(
+                 self.cfg_dir, 'optimus.yaml')
+         else:
+             raise ValueError
+
+
+ class dataset_cfg_bank(object):
+     def __init__(self):
+         self.cfg_dir = osp.join('configs', 'dataset')
+         self.cfg_bank = edict()
+
+     def __call__(self, name):
+         if name not in self.cfg_bank:
+             cfg_path = self.get_yaml_path(name)
+             with open(cfg_path, 'r') as f:
+                 cfg_new = yaml.load(
+                     f, Loader=yaml.FullLoader)
+             cfg_new = edict(cfg_new)
+             self.cfg_bank.update(cfg_new)
+
+         cfg = self.cfg_bank[name]
+         cfg.name = name
+         if cfg.get('super_cfg', None) is None:
+             cfg = cfg_solve(cfg, cfg)
+             self.cfg_bank[name] = cfg
+             return copy.deepcopy(cfg)
+
+         super_cfg = self.__call__(cfg.super_cfg)
+         super_cfg.update(cfg)
+         cfg = super_cfg
+         cfg.super_cfg = None
+         try:
+             delete = cfg.pop('delete')
+         except:
+             delete = []
+
+         for dargs in delete:
+             cfg.pop(dargs)
+
+         cfg = cfg_solve(cfg, cfg)
+         self.cfg_bank[name] = cfg
+         return copy.deepcopy(cfg)
+
+     def get_yaml_path(self, name):
+         if name.find('cityscapes')==0:
+             return osp.join(
+                 self.cfg_dir, 'cityscapes.yaml')
+         elif name.find('div2k')==0:
+             return osp.join(
+                 self.cfg_dir, 'div2k.yaml')
+         elif name.find('gandiv2k')==0:
+             return osp.join(
+                 self.cfg_dir, 'gandiv2k.yaml')
+         elif name.find('srbenchmark')==0:
+             return osp.join(
+                 self.cfg_dir, 'srbenchmark.yaml')
+         elif name.find('imagedir')==0:
+             return osp.join(
+                 self.cfg_dir, 'imagedir.yaml')
+         elif name.find('places2')==0:
+             return osp.join(
+                 self.cfg_dir, 'places2.yaml')
+         elif name.find('ffhq')==0:
+             return osp.join(
+                 self.cfg_dir, 'ffhq.yaml')
+         elif name.find('imcpt')==0:
+             return osp.join(
+                 self.cfg_dir, 'imcpt.yaml')
+         elif name.find('texture')==0:
+             return osp.join(
+                 self.cfg_dir, 'texture.yaml')
+         elif name.find('openimages')==0:
+             return osp.join(
+                 self.cfg_dir, 'openimages.yaml')
+         elif name.find('laion2b')==0:
+             return osp.join(
+                 self.cfg_dir, 'laion2b.yaml')
+         elif name.find('laionart')==0:
+             return osp.join(
+                 self.cfg_dir, 'laionart.yaml')
+         elif name.find('celeba')==0:
+             return osp.join(
+                 self.cfg_dir, 'celeba.yaml')
+         elif name.find('coyo')==0:
+             return osp.join(
+                 self.cfg_dir, 'coyo.yaml')
+         elif name.find('pafc')==0:
+             return osp.join(
+                 self.cfg_dir, 'pafc.yaml')
+         elif name.find('coco')==0:
+             return osp.join(
+                 self.cfg_dir, 'coco.yaml')
+         else:
+             raise ValueError
+
+
+ class experiment_cfg_bank(object):
+     def __init__(self):
+         self.cfg_dir = osp.join('configs', 'experiment')
+         self.cfg_bank = edict()
+
+     def __call__(self, name):
+         if name not in self.cfg_bank:
+             cfg_path = self.get_yaml_path(name)
+             with open(cfg_path, 'r') as f:
+                 cfg = yaml.load(
+                     f, Loader=yaml.FullLoader)
+             cfg = edict(cfg)
+
+         cfg = cfg_solve(cfg, cfg)
+         cfg = cfg_solve(cfg, cfg)
+         # twice for SEARCH
+         self.cfg_bank[name] = cfg
+         return copy.deepcopy(cfg)
+
+     def get_yaml_path(self, name):
+         return osp.join(
+             self.cfg_dir, name+'.yaml')
+
+
+ def load_cfg_yaml(path):
+     if osp.isfile(path):
+         cfg_path = path
+     elif osp.isfile(osp.join('configs', 'experiment', path)):
+         cfg_path = osp.join('configs', 'experiment', path)
+     elif osp.isfile(osp.join('configs', 'experiment', path+'.yaml')):
+         cfg_path = osp.join('configs', 'experiment', path+'.yaml')
+     else:
+         assert False, 'No such config!'
+
+     with open(cfg_path, 'r') as f:
+         cfg = yaml.load(f, Loader=yaml.FullLoader)
+     cfg = edict(cfg)
+     cfg = cfg_solve(cfg, cfg)
+     cfg = cfg_solve(cfg, cfg)
+     return cfg
+
+ ##############
+ # cfg_helper #
+ ##############
+
+
+ def get_experiment_id(ref=None):
+     if ref is None:
+         time.sleep(0.5)
+         return int(time.time()*100)
+     else:
+         try:
+             return int(ref)
+         except:
+             pass
+
+         _, ref = osp.split(ref)
+         ref = ref.split('_')[0]
+         try:
+             return int(ref)
+         except:
+             assert False, 'Invalid experiment ID!'
+
+
+ def record_resume_cfg(path):
+     cnt = 0
+     while True:
+         if osp.exists(path+'.{:04d}'.format(cnt)):
+             cnt += 1
+             continue
+         shutil.copyfile(path, path+'.{:04d}'.format(cnt))
+         break
+
+
+ def get_command_line_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--debug', action='store_true', default=False)
+     parser.add_argument('--config', type=str)
+     parser.add_argument('--gpu', nargs='+', type=int)
+
+     parser.add_argument('--node_rank', type=int, default=0)
+     parser.add_argument('--nodes', type=int, default=1)
+     parser.add_argument('--addr', type=str, default='127.0.0.1')
+     parser.add_argument('--port', type=int, default=11233)
+
+     parser.add_argument('--signature', nargs='+', type=str)
+     parser.add_argument('--seed', type=int)
+
+     parser.add_argument('--eval', type=str)
+     parser.add_argument('--eval_subdir', type=str)
+     parser.add_argument('--pretrained', type=str)
+
+     parser.add_argument('--resume_dir', type=str)
+     parser.add_argument('--resume_step', type=int)
+     parser.add_argument('--resume_weight', type=str)
+
+     args = parser.parse_args()
+
+     # Special handling for resume
+     if args.resume_dir is not None:
+         cfg = edict()
+         cfg.env = edict()
+         cfg.env.debug = args.debug
+         cfg.env.resume = edict()
+         cfg.env.resume.dir = args.resume_dir
+         cfg.env.resume.step = args.resume_step
+         cfg.env.resume.weight = args.resume_weight
+         return cfg
+
+     cfg = load_cfg_yaml(args.config)
+     cfg.env.debug = args.debug
+     cfg.env.gpu_device = [0] if args.gpu is None else list(args.gpu)
+     cfg.env.master_addr = args.addr
+     cfg.env.master_port = args.port
+     cfg.env.dist_url = 'tcp://{}:{}'.format(args.addr, args.port)
+     cfg.env.node_rank = args.node_rank
+     cfg.env.nodes = args.nodes
+
+     istrain = False if args.eval is not None else True
+     isdebug = cfg.env.debug
+
+     if istrain:
+         if isdebug:
+             cfg.env.experiment_id = 999999999999
+             cfg.train.signature = ['debug']
+         else:
+             cfg.env.experiment_id = get_experiment_id()
+             if args.signature is not None:
+                 cfg.train.signature = args.signature
+     else:
+         if 'train' in cfg:
+             cfg.pop('train')
+         cfg.env.experiment_id = get_experiment_id(args.eval)
+         if args.signature is not None:
+             cfg.eval.signature = args.signature
+
+         if isdebug and (args.eval is None):
+             cfg.env.experiment_id = 999999999999
+             cfg.eval.signature = ['debug']
+
+     if args.eval_subdir is not None:
+         if isdebug:
+             cfg.eval.eval_subdir = 'debug'
+         else:
+             cfg.eval.eval_subdir = args.eval_subdir
+     if args.pretrained is not None:
+         cfg.eval.pretrained = args.pretrained
+         # The --pretrained flag overrides the setting in cfg.model
+     if args.seed is not None:
+         cfg.env.rnd_seed = args.seed
+     return cfg
+
+
+ def cfg_initiates(cfg):
+     cfge = cfg.env
+     isdebug = cfge.debug
+     isresume = 'resume' in cfge
+     istrain = 'train' in cfg
+     haseval = 'eval' in cfg
+     cfgt = cfg.train if istrain else None
+     cfgv = cfg.eval if haseval else None
+
+     ###############################
+     # get some environment params #
+     ###############################
+
+     cfge.computer = os.uname()
+     cfge.torch_version = str(torch.__version__)
+
+     ##########
+     # resume #
+     ##########
+
+     if isresume:
+         resume_cfg_path = osp.join(cfge.resume.dir, 'config.yaml')
+         record_resume_cfg(resume_cfg_path)
+         with open(resume_cfg_path, 'r') as f:
+             cfg_resume = yaml.load(f, Loader=yaml.FullLoader)
+         cfg_resume = edict(cfg_resume)
+         cfg_resume.env.update(cfge)
+         cfg = cfg_resume
+         cfge = cfg.env
+         log_file = cfg.train.log_file
+
+         print('')
+         print('##########')
+         print('# resume #')
+         print('##########')
+         print('')
+         with open(log_file, 'a') as f:
+             print('', file=f)
+             print('##########', file=f)
+             print('# resume #', file=f)
+             print('##########', file=f)
+             print('', file=f)
+
+         pprint.pprint(cfg)
+         with open(log_file, 'a') as f:
+             pprint.pprint(cfg, f)
+
+     ####################
+     # node distributed #
+     ####################
+
+     if cfg.env.master_addr!='127.0.0.1':
+         os.environ['MASTER_ADDR'] = cfge.master_addr
+         os.environ['MASTER_PORT'] = '{}'.format(cfge.master_port)
+         if cfg.env.dist_backend=='nccl':
+             os.environ['NCCL_SOCKET_FAMILY'] = 'AF_INET'
+         if cfg.env.dist_backend=='gloo':
+             os.environ['GLOO_SOCKET_FAMILY'] = 'AF_INET'
+
+     #######################
+     # cuda visible device #
+     #######################
+
+     os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
+         [str(gid) for gid in cfge.gpu_device])
+
+     #####################
+     # return resume cfg #
+     #####################
+
+     if isresume:
+         return cfg
+
+     #############################################
+     # some misc setting that not need in resume #
+     #############################################
+
+     cfgm = cfg.model
+     cfge.gpu_count = len(cfge.gpu_device)
+
+     ##########################################
+     # align batch size and num worker config #
+     ##########################################
+
+     gpu_n = cfge.gpu_count * cfge.nodes
+
+     def align_batch_size(bs, bs_per_gpu):
+         assert (bs is not None) or (bs_per_gpu is not None)
+         bs = bs_per_gpu * gpu_n if bs is None else bs
+         bs_per_gpu = bs // gpu_n if bs_per_gpu is None else bs_per_gpu
+         assert (bs == bs_per_gpu * gpu_n)
+         return bs, bs_per_gpu
+
+     if istrain:
+         cfgt.batch_size, cfgt.batch_size_per_gpu = \
+             align_batch_size(cfgt.batch_size, cfgt.batch_size_per_gpu)
+         cfgt.dataset_num_workers, cfgt.dataset_num_workers_per_gpu = \
+             align_batch_size(cfgt.dataset_num_workers, cfgt.dataset_num_workers_per_gpu)
+     if haseval:
+         cfgv.batch_size, cfgv.batch_size_per_gpu = \
+             align_batch_size(cfgv.batch_size, cfgv.batch_size_per_gpu)
+         cfgv.dataset_num_workers, cfgv.dataset_num_workers_per_gpu = \
+             align_batch_size(cfgv.dataset_num_workers, cfgv.dataset_num_workers_per_gpu)
+
+     ##################
+     # create log dir #
+     ##################
+
+     if istrain:
+         if not isdebug:
+             sig = cfgt.get('signature', [])
+             version = get_model().get_version(cfgm.type)
+             sig = sig + ['v{}'.format(version), 's{}'.format(cfge.rnd_seed)]
+         else:
+             sig = ['debug']
+
+         log_dir = [
+             cfge.log_root_dir,
+             '{}_{}'.format(cfgm.symbol, cfgt.dataset.symbol),
+             '_'.join([str(cfge.experiment_id)] + sig)
+         ]
+         log_dir = osp.join(*log_dir)
+         log_file = osp.join(log_dir, 'train.log')
+         if not osp.exists(log_file):
+             os.makedirs(osp.dirname(log_file))
+         cfgt.log_dir = log_dir
+         cfgt.log_file = log_file
+
+         if haseval:
+             cfgv.log_dir = log_dir
+             cfgv.log_file = log_file
+     else:
+         model_symbol = cfgm.symbol
+         if cfgv.get('dataset', None) is None:
+             dataset_symbol = 'nodataset'
+         else:
+             dataset_symbol = cfgv.dataset.symbol
+
+         log_dir = osp.join(cfge.log_root_dir, '{}_{}'.format(model_symbol, dataset_symbol))
+         exp_dir = search_experiment_folder(log_dir, cfge.experiment_id)
+         if exp_dir is None:
+             if not isdebug:
+                 sig = cfgv.get('signature', []) + ['evalonly']
+             else:
+                 sig = ['debug']
+             exp_dir = '_'.join([str(cfge.experiment_id)] + sig)
+
+         eval_subdir = cfgv.get('eval_subdir', None)
+         # override subdir in debug mode (if eval_subdir is set)
+         eval_subdir = 'debug' if (eval_subdir is not None) and isdebug else eval_subdir
+
+         if eval_subdir is not None:
+             log_dir = osp.join(log_dir, exp_dir, eval_subdir)
+         else:
+             log_dir = osp.join(log_dir, exp_dir)
+
+         disable_log_override = cfgv.get('disable_log_override', False)
+         if osp.isdir(log_dir):
+             if disable_log_override:
+                 assert False, 'Overriding an existing log_dir is disabled at [{}]'.format(log_dir)
+         else:
+             os.makedirs(log_dir)
+
+         log_file = osp.join(log_dir, 'eval.log')
+         cfgv.log_dir = log_dir
+         cfgv.log_file = log_file
+
+     ######################
+     # print and save cfg #
+     ######################
+
+     pprint.pprint(cfg)
+     with open(log_file, 'w') as f:
+         pprint.pprint(cfg, f)
+     with open(osp.join(log_dir, 'config.yaml'), 'w') as f:
+         yaml.dump(edict_2_dict(cfg), f)
+
+     #############
+     # save code #
+     #############
+
+     save_code = False
+     if istrain:
+         save_code = cfgt.get('save_code', False)
+     elif haseval:
+         save_code = cfgv.get('save_code', False)
+
+     if save_code:
+         codedir = osp.join(log_dir, 'code')
+         if osp.exists(codedir):
+             shutil.rmtree(codedir)
+         for d in ['configs', 'lib']:
+             fromcodedir = d
+             tocodedir = osp.join(codedir, d)
+             shutil.copytree(
+                 fromcodedir, tocodedir,
+                 ignore=shutil.ignore_patterns(
+                     '*__pycache__*', '*build*'))
+         for codei in os.listdir('.'):
+             if osp.splitext(codei)[1] == 'py':
+                 shutil.copy(codei, codedir)
+
+     #######################
+     # set matplotlib mode #
+     #######################
+
+     if 'matplotlib_mode' in cfge:
+         try:
+             matplotlib.use(cfge.matplotlib_mode)
+         except:
+             print('Warning: matplotlib mode [{}] failed to be set!'.format(cfge.matplotlib_mode))
+
+     return cfg
+
+
+ def edict_2_dict(x):
+     if isinstance(x, dict):
+         xnew = {}
+         for k in x:
+             xnew[k] = edict_2_dict(x[k])
+         return xnew
+     elif isinstance(x, list):
+         xnew = []
+         for i in range(len(x)):
+             xnew.append( edict_2_dict(x[i]) )
+         return xnew
+     else:
+         return x
+
+
+ def search_experiment_folder(root, exid):
+     target = None
+     for fi in os.listdir(root):
+         if not osp.isdir(osp.join(root, fi)):
+             continue
+         if int(fi.split('_')[0]) == exid:
+             if target is not None:
+                 return None  # duplicated
+             elif target is None:
+                 target = fi
+     return target
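
Taken together, these helpers imply an entry-point pattern of the following shape. A hedged sketch; the script name is hypothetical, and the flags are the ones defined in get_command_line_args() above (not part of the commit):

    # Sketch only. Typical invocation of a hypothetical entry script:
    #   python main.py --config my_experiment --gpu 0 1 --seed 20
    from core.cfg_helper import get_command_line_args, cfg_initiates

    if __name__ == '__main__':
        cfg = get_command_line_args()  # parse --config / --gpu / --eval / --resume_* ...
        cfg = cfg_initiates(cfg)       # record env info, set CUDA_VISIBLE_DEVICES, create log dir
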
core/cfg_holder.py ADDED
@@ -0,0 +1,33 @@
+ import copy
+
+
+ def singleton(class_):
+     instances = {}
+
+     def getinstance(*args, **kwargs):
+         if class_ not in instances:
+             instances[class_] = class_(*args, **kwargs)
+         return instances[class_]
+     return getinstance
+
+ ##############
+ # cfg_holder #
+ ##############
+
+
+ @singleton
+ class cfg_unique_holder(object):
+     def __init__(self):
+         self.cfg = None
+         # this is used to track the main codes.
+         self.code = set()
+
+     def save_cfg(self, cfg):
+         self.cfg = copy.deepcopy(cfg)
+
+     def add_code(self, code):
+         """
+         A new main code is reached and
+         its name is added.
+         """
+         self.code.add(code)
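
Because of the `singleton` decorator, every call to `cfg_unique_holder()` returns the same instance, so a config saved once is visible from anywhere in the process. A small sketch (not part of the commit):

    # Sketch only: the holder behaves as a process-wide singleton.
    from core.cfg_holder import cfg_unique_holder as cfguh

    cfguh().save_cfg({'model': 'codi'})  # stored as a deep copy
    assert cfguh() is cfguh()            # same object on every call
    print(cfguh().cfg)                   # {'model': 'codi'}
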
core/common/__pycache__/utils.cpython-38.pyc ADDED
Binary file (11.3 kB).
core/common/registry.py ADDED
@@ -0,0 +1,86 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from argparse import Namespace
+ from typing import Union
+
+ from hydra.core.config_store import ConfigStore
+ from omegaconf import DictConfig
+
+ REGISTRIES = {}
+
+
+ def setup_registry(registry_name: str,
+                    base_class=None,
+                    default=None,
+                    required=False):
+     assert registry_name.startswith('--')
+     registry_name = registry_name[2:].replace('-', '_')
+
+     REGISTRY = {}
+     REGISTRY_CLASS_NAMES = set()
+     DATACLASS_REGISTRY = {}
+
+     # maintain a registry of all registries
+     if registry_name in REGISTRIES:
+         return  # registry already exists
+     REGISTRIES[registry_name] = {
+         'registry': REGISTRY,
+         'default': default,
+         'dataclass_registry': DATACLASS_REGISTRY,
+     }
+
+     def build_x(cfg: Union[DictConfig, str, Namespace], *extra_args,
+                 **extra_kwargs):
+
+         assert isinstance(cfg, str)
+         choice = cfg
+         if choice in DATACLASS_REGISTRY:
+             cfg = DATACLASS_REGISTRY[choice]()
+
+         if choice is None:
+             if required:
+                 raise ValueError('{} is required!'.format(registry_name))
+             return None
+
+         cls = REGISTRY[choice]
+         if hasattr(cls, 'build_' + registry_name):
+             builder = getattr(cls, 'build_' + registry_name)
+         else:
+             builder = cls
+         return builder(cfg, *extra_args, **extra_kwargs)
+
+     def register_x(name, dataclass=None):
+         def register_x_cls(cls):
+             if name in REGISTRY:
+                 raise ValueError('Cannot register duplicate {} ({})'.format(
+                     registry_name, name))
+             if cls.__name__ in REGISTRY_CLASS_NAMES:
+                 raise ValueError(
+                     'Cannot register {} with duplicate class name ({})'.format(
+                         registry_name, cls.__name__))
+             if base_class is not None and not issubclass(cls, base_class):
+                 raise ValueError('{} must extend {}'.format(
+                     cls.__name__, base_class.__name__))
+
+             cls.__dataclass = dataclass
+             if cls.__dataclass is not None:
+                 DATACLASS_REGISTRY[name] = cls.__dataclass
+
+                 cs = ConfigStore.instance()
+                 node = dataclass()
+                 node._name = name
+                 cs.store(name=name,
+                          group=registry_name,
+                          node=node,
+                          provider='layoutlmft')
+
+             REGISTRY[name] = cls
+
+             return cls
+
+         return register_x_cls
+
+     return build_x, register_x, REGISTRY, DATACLASS_REGISTRY
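
`setup_registry` follows the fairseq pattern: for one named registry it returns a `build_x` factory and a `register_x` decorator. A hedged sketch with hypothetical names (the '--encoder' registry and MyEncoder class below are illustrative only, not part of the commit):

    # Sketch only: create a registry, register a class, then build by name.
    from core.common.registry import setup_registry

    build_encoder, register_encoder, ENCODER_REGISTRY, _ = setup_registry(
        '--encoder', default=None, required=False)

    @register_encoder('my_encoder')
    class MyEncoder:
        def __init__(self, cfg):
            self.cfg = cfg

    enc = build_encoder('my_encoder')  # without a dataclass, cfg stays the string name
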
core/common/utils.py ADDED
@@ -0,0 +1,412 @@
+ import random
+ import torch
+ from collections import OrderedDict
+
+ import numpy as np
+ from PIL import Image
+ import torchvision.transforms as T
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor
+ from torchvision import transforms as tvtrans
+
+ from decord import VideoReader, cpu, gpu
+
+
+ ###############
+ # text helper #
+ ###############
+
+
+ def remove_duplicate_word(tx):
+     def combine_words(input, length):
+         combined_inputs = []
+         if len(splitted_input) > 1:
+             for i in range(len(input) - 1):
+                 combined_inputs.append(input[i] + " " + last_word_of(splitted_input[i + 1], length))
+                 # add the last word of the right-neighbour (overlapping) sequence (before it has expanded), which is the next word in the original sentence
+         return combined_inputs, length + 1
+
+     def remove_duplicates(input, length):
+         bool_broke = False  # this means we didn't find any duplicates here
+         for i in range(len(input) - length):
+             if input[i] == input[i + length]:  # found a duplicate piece of sentence!
+                 for j in range(0, length):  # remove the overlapping sequences in reverse order
+                     del input[i + length - j]
+                 bool_broke = True
+                 break  # break the for loop, as the loop length no longer matches the length of splitted_input after removing elements
+         if bool_broke:
+             return remove_duplicates(input, length)
+             # if we found a duplicate, look for another duplicate of the same length
+         return input
+
+     def last_word_of(input, length):
+         splitted = input.split(" ")
+         if len(splitted) == 0:
+             return input
+         else:
+             return splitted[length - 1]
+
+     def split_and_puncsplit(text):
+         tx = text.split(" ")
+         txnew = []
+         for txi in tx:
+             txqueue = []
+             while True:
+                 if txi[0] in '([{':
+                     txqueue.extend([txi[:1], '<puncnext>'])
+                     txi = txi[1:]
+                     if len(txi) == 0:
+                         break
+                 else:
+                     break
+             txnew += txqueue
+             txstack = []
+             if len(txi) == 0:
+                 continue
+             while True:
+                 if txi[-1] in '?!.,:;}])':
+                     txstack = ['<puncnext>', txi[-1:]] + txstack
+                     txi = txi[:-1]
+                     if len(txi) == 0:
+                         break
+                 else:
+                     break
+             if len(txi) != 0:
+                 txnew += [txi]
+             txnew += txstack
+         return txnew
+
+     if tx == '':
+         return tx
+
+     splitted_input = split_and_puncsplit(tx)
+     word_length = 1
+     intermediate_output = False
+     while len(splitted_input) > 1:
+         splitted_input = remove_duplicates(splitted_input, word_length)
+         if len(splitted_input) > 1:
+             splitted_input, word_length = combine_words(splitted_input, word_length)
+         if intermediate_output:
+             print(splitted_input)
+             print(word_length)
+     output = splitted_input[0]
+     output = output.replace(' <puncnext> ', '')
+     return output
+
+
+ #################
+ # vision helper #
+ #################
+
+
+ def regularize_image(x, image_size=512):
+     BICUBIC = T.InterpolationMode.BICUBIC
+     if isinstance(x, str):
+         x = Image.open(x)
+         size = min(x.size)
+     elif isinstance(x, Image.Image):
+         x = x.convert('RGB')
+         size = min(x.size)
+     elif isinstance(x, np.ndarray):
+         x = Image.fromarray(x).convert('RGB')
+         size = min(x.size)
+     elif isinstance(x, torch.Tensor):
+         # normalize to [0, 1]
+         x = x/255.0
+         size = min(x.size()[1:])
+     else:
+         assert False, 'Unknown image type'
+
+     """transforms = T.Compose([
+         T.RandomCrop(size),
+         T.Resize(
+             (image_size, image_size),
+             interpolation=BICUBIC,
+         ),
+         T.RandomHorizontalFlip(),
+         T.ToTensor(),
+     ])
+     x = transforms(x)
+
+     assert (x.shape[1] == image_size) & (x.shape[2] == image_size), \
+         'Wrong image size'
+     """
+     x = x * 2 - 1
+     return x
+
+
+ def center_crop(img, new_width=None, new_height=None):
+     width = img.shape[2]
+     height = img.shape[1]
+
+     if new_width is None:
+         new_width = min(width, height)
+
+     if new_height is None:
+         new_height = min(width, height)
+
+     left = int(np.ceil((width - new_width) / 2))
+     right = width - int(np.floor((width - new_width) / 2))
+
+     top = int(np.ceil((height - new_height) / 2))
+     bottom = height - int(np.floor((height - new_height) / 2))
+     if len(img.shape) == 3:
+         center_cropped_img = img[:, top:bottom, left:right]
+     else:
+         center_cropped_img = img[:, top:bottom, left:right, ...]
+
+     return center_cropped_img
+
+
+ def _transform(n_px):
+     return Compose([
+         Resize([n_px, n_px], interpolation=T.InterpolationMode.BICUBIC), ])
+
+
+ def regularize_video(video, image_size=256):
+     min_shape = min(video.shape[1:3])
+     video = center_crop(video, min_shape, min_shape)
+     video = torch.from_numpy(video).permute(0, 3, 1, 2)
+     video = _transform(image_size)(video)
+     video = video / 255.0 * 2.0 - 1.0
+     return video.permute(1, 0, 2, 3)
+
+
+ def time_to_indices(video_reader, time):
+     times = video_reader.get_frame_timestamp(range(len(video_reader))).mean(-1)
+     indices = np.searchsorted(times, time)
+     # Use `np.bitwise_or` so it works both with scalars and numpy arrays.
+     return np.where(np.bitwise_or(indices == 0, times[indices] - time <= time - times[indices - 1]), indices,
+                     indices - 1)
+
+
+ def load_video(video_path, sample_duration=8.0, num_frames=8):
+     sample_duration = 4.0
+     num_frames = 4
+
+     vr = VideoReader(video_path, ctx=cpu(0))
+     framerate = vr.get_avg_fps()
+     video_frame_len = len(vr)
+     video_len = video_frame_len / framerate
+     sample_duration = min(sample_duration, video_len)
+
+     if video_len > sample_duration:
+         s = random.random() * (video_len - sample_duration)
+         t = s + sample_duration
+         start, end = time_to_indices(vr, [s, t])
+         end = min(video_frame_len - 1, end)
+         start = min(start, end - 1)
+         downsamlp_indices = np.linspace(start, end, num_frames, endpoint=True).astype(int).tolist()
+     else:
+         downsamlp_indices = np.linspace(0, video_frame_len - 1, num_frames, endpoint=True).astype(int).tolist()
+
+     video = vr.get_batch(downsamlp_indices).asnumpy()
+     return video
+
+
+ ###############
+ # some helper #
+ ###############
+
+ def atomic_save(cfg, net, opt, step, path):
+     if isinstance(net, (torch.nn.DataParallel,
+                         torch.nn.parallel.DistributedDataParallel)):
+         netm = net.module
+     else:
+         netm = net
+     sd = netm.state_dict()
+     slimmed_sd = [(ki, vi) for ki, vi in sd.items()
+                   if ki.find('first_stage_model') != 0 and ki.find('cond_stage_model') != 0]
+
+     checkpoint = {
+         "config": cfg,
+         "state_dict": OrderedDict(slimmed_sd),
+         "step": step}
+     if opt is not None:
+         checkpoint['optimizer_states'] = opt.state_dict()
+     import io
+     import fsspec
+     bytesbuffer = io.BytesIO()
+     torch.save(checkpoint, bytesbuffer)
+     with fsspec.open(path, "wb") as f:
+         f.write(bytesbuffer.getvalue())
+
+
+ def load_state_dict(net, cfg):
+     pretrained_pth_full = cfg.get('pretrained_pth_full', None)
+     pretrained_ckpt_full = cfg.get('pretrained_ckpt_full', None)
+     pretrained_pth = cfg.get('pretrained_pth', None)
+     pretrained_ckpt = cfg.get('pretrained_ckpt', None)
+     pretrained_pth_dm = cfg.get('pretrained_pth_dm', None)
+     pretrained_pth_ema = cfg.get('pretrained_pth_ema', None)
+     strict_sd = cfg.get('strict_sd', False)
+     errmsg = "Overlapped model state_dict! This is undesired behavior!"
+
+     if pretrained_pth_full is not None or pretrained_ckpt_full is not None:
+         assert (pretrained_pth is None) and \
+             (pretrained_ckpt is None) and \
+             (pretrained_pth_dm is None) and \
+             (pretrained_pth_ema is None), errmsg
+         if pretrained_pth_full is not None:
+             target_file = pretrained_pth_full
+             sd = torch.load(target_file, map_location='cpu')
+             assert pretrained_ckpt is None, errmsg
+         else:
+             target_file = pretrained_ckpt_full
+             sd = torch.load(target_file, map_location='cpu')['state_dict']
+         print('Load full model from [{}] strict [{}].'.format(
+             target_file, strict_sd))
+         net.load_state_dict(sd, strict=strict_sd)
+
+     if pretrained_pth is not None or pretrained_ckpt is not None:
+         assert (pretrained_ckpt_full is None) and \
+             (pretrained_pth_full is None) and \
+             (pretrained_pth_dm is None) and \
+             (pretrained_pth_ema is None), errmsg
+         if pretrained_pth is not None:
+             target_file = pretrained_pth
+             sd = torch.load(target_file, map_location='cpu')
+             assert pretrained_ckpt is None, errmsg
+         else:
+             target_file = pretrained_ckpt
+             sd = torch.load(target_file, map_location='cpu')['state_dict']
+         print('Load model from [{}] strict [{}].'.format(
+             target_file, strict_sd))
+         sd_extra = [(ki, vi) for ki, vi in net.state_dict().items() \
+                     if ki.find('first_stage_model') == 0 or ki.find('cond_stage_model') == 0]
+         sd.update(OrderedDict(sd_extra))
+         net.load_state_dict(sd, strict=strict_sd)
+
+     if pretrained_pth_dm is not None:
+         assert (pretrained_ckpt_full is None) and \
+             (pretrained_pth_full is None) and \
+             (pretrained_pth is None) and \
+             (pretrained_ckpt is None), errmsg
+         print('Load diffusion model from [{}] strict [{}].'.format(
+             pretrained_pth_dm, strict_sd))
+         sd = torch.load(pretrained_pth_dm, map_location='cpu')
+         net.model.diffusion_model.load_state_dict(sd, strict=strict_sd)
+
+     if pretrained_pth_ema is not None:
+         assert (pretrained_ckpt_full is None) and \
+             (pretrained_pth_full is None) and \
+             (pretrained_pth is None) and \
+             (pretrained_ckpt is None), errmsg
+         print('Load unet ema model from [{}] strict [{}].'.format(
+             pretrained_pth_ema, strict_sd))
+         sd = torch.load(pretrained_pth_ema, map_location='cpu')
+         net.model_ema.load_state_dict(sd, strict=strict_sd)
+
+
+ def auto_merge_imlist(imlist, max=64):
+     imlist = imlist[0:max]
+     h, w = imlist[0].shape[0:2]
+     num_images = len(imlist)
+     num_row = int(np.sqrt(num_images))
+     num_col = num_images // num_row + 1 if num_images % num_row != 0 else num_images // num_row
+     canvas = np.zeros([num_row * h, num_col * w, 3], dtype=np.uint8)
+     for idx, im in enumerate(imlist):
+         hi = (idx // num_col) * h
+         wi = (idx % num_col) * w
+         canvas[hi:hi + h, wi:wi + w, :] = im
+     return canvas
+
+
+ def latent2im(net, latent):
+     single_input = len(latent.shape) == 3
+     if single_input:
+         latent = latent[None]
+     im = net.decode_image(latent.to(net.device))
+     im = torch.clamp((im + 1.0) / 2.0, min=0.0, max=1.0)
+     im = [tvtrans.ToPILImage()(i) for i in im]
+     if single_input:
+         im = im[0]
+     return im
+
+
+ def im2latent(net, im):
+     single_input = not isinstance(im, list)
+     if single_input:
+         im = [im]
+     im = torch.stack([tvtrans.ToTensor()(i) for i in im], dim=0)
+     im = (im * 2 - 1).to(net.device)
+     z = net.encode_image(im)
+     if single_input:
+         z = z[0]
+     return z
+
+
+ class color_adjust(object):
+     def __init__(self, ref_from, ref_to):
+         x0, m0, std0 = self.get_data_and_stat(ref_from)
+         x1, m1, std1 = self.get_data_and_stat(ref_to)
+         self.ref_from_stat = (m0, std0)
+         self.ref_to_stat = (m1, std1)
+         self.ref_from = self.preprocess(x0).reshape(-1, 3)
+         self.ref_to = x1.reshape(-1, 3)
+
+     def get_data_and_stat(self, x):
+         if isinstance(x, str):
+             x = np.array(Image.open(x))
+         elif isinstance(x, Image.Image):
+             x = np.array(x)
+         elif isinstance(x, torch.Tensor):
+             x = torch.clamp(x, min=0.0, max=1.0)
+             x = np.array(tvtrans.ToPILImage()(x))
+         elif isinstance(x, np.ndarray):
+             pass
+         else:
+             raise ValueError
+         x = x.astype(float)
+         m = np.reshape(x, (-1, 3)).mean(0)
+         s = np.reshape(x, (-1, 3)).std(0)
+         return x, m, s
+
+     def preprocess(self, x):
+         m0, s0 = self.ref_from_stat
+         m1, s1 = self.ref_to_stat
+         y = ((x - m0) / s0) * s1 + m1
+         return y
+
+     def __call__(self, xin, keep=0, simple=False):
+         xin, _, _ = self.get_data_and_stat(xin)
+         x = self.preprocess(xin)
+         if simple:
+             y = (x * (1 - keep) + xin * keep)
+             y = np.clip(y, 0, 255).astype(np.uint8)
+             return y
+
+         h, w = x.shape[:2]
+         x = x.reshape(-1, 3)
+         y = []
+         for chi in range(3):
+             yi = self.pdf_transfer_1d(self.ref_from[:, chi], self.ref_to[:, chi], x[:, chi])
+             y.append(yi)
+
+         y = np.stack(y, axis=1)
+         y = y.reshape(h, w, 3)
+         y = (y.astype(float) * (1 - keep) + xin.astype(float) * keep)
+         y = np.clip(y, 0, 255).astype(np.uint8)
+         return y
+
+     def pdf_transfer_1d(self, arr_fo, arr_to, arr_in, n=600):
+         arr = np.concatenate((arr_fo, arr_to))
+         min_v = arr.min() - 1e-6
+         max_v = arr.max() + 1e-6
+         min_vto = arr_to.min() - 1e-6
+         max_vto = arr_to.max() + 1e-6
+         xs = np.array(
+             [min_v + (max_v - min_v) * i / n for i in range(n + 1)])
+         hist_fo, _ = np.histogram(arr_fo, xs)
+         hist_to, _ = np.histogram(arr_to, xs)
+         xs = xs[:-1]
+         # compute probability distribution
+         cum_fo = np.cumsum(hist_fo)
+         cum_to = np.cumsum(hist_to)
+         d_fo = cum_fo / cum_fo[-1]
+         d_to = cum_to / cum_to[-1]
+         # transfer
+         t_d = np.interp(d_fo, d_to, xs)
+         t_d[d_fo <= d_to[0]] = min_vto
+         t_d[d_fo >= d_to[-1]] = max_vto
+         arr_out = np.interp(arr_in, xs, t_d)
+         return arr_out
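
A minimal sketch of the intended video path through the helpers above; the file path is a placeholder, and decord plus torchvision must be installed (not part of the commit):

    # Sketch only: sample frames from a clip and normalize them for the model.
    from core.common.utils import load_video, regularize_video

    frames = load_video('clip.mp4')   # uint8 array of sampled frames, (T, H, W, 3)
    video = regularize_video(frames)  # float tensor (3, T, 256, 256), scaled to [-1, 1]
    print(video.shape)
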
core/models/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .common.get_model import get_model
+ from .common.get_optimizer import get_optimizer
+ from .common.get_scheduler import get_scheduler
+ from .common.utils import get_unit
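
These re-exports are the intended public surface of core.models; a registered model is looked up by its config's `type` and constructed from the resolved config. A sketch, assuming pretrained weights and all dependencies are available (not part of the commit):

    # Sketch only: resolve configs/model/codi.yaml and build the model.
    from core.cfg_helper import model_cfg_bank
    from core.models import get_model

    cfg = model_cfg_bank()('codi')
    net = get_model()(cfg)  # registry lookup by cfg.type, then construction
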
core/models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (367 Bytes).
core/models/__pycache__/codi.cpython-38.pyc ADDED
Binary file (7.7 kB).
core/models/__pycache__/codi_2.cpython-38.pyc ADDED
Binary file (7.12 kB).
core/models/__pycache__/dani_model.cpython-38.pyc ADDED
Binary file (4.29 kB).
core/models/__pycache__/ema.cpython-38.pyc ADDED
Binary file (2.99 kB).
core/models/__pycache__/model_module_infer.cpython-38.pyc ADDED
Binary file (4.31 kB).
core/models/__pycache__/sd.cpython-38.pyc ADDED
Binary file (9.82 kB).
core/models/codi.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List
2
+ import os
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import numpy as np
8
+ import numpy.random as npr
9
+ import copy
10
+ from functools import partial
11
+ from contextlib import contextmanager
12
+
13
+ from .common.get_model import get_model, register
14
+ from .sd import DDPM
15
+
16
+ version = '0'
17
+ symbol = 'codi'
18
+
19
+
20
+ @register('codi', version)
21
+ class CoDi(DDPM):
22
+ def __init__(self,
23
+ audioldm_cfg=None,
24
+ autokl_cfg=None,
25
+ optimus_cfg=None,
26
+ clip_cfg=None,
27
+ clap_cfg=None,
28
+ vision_scale_factor=0.1812,
29
+ text_scale_factor=4.3108,
30
+ audio_scale_factor=0.9228,
31
+ scale_by_std=False,
32
+ *args,
33
+ **kwargs):
34
+ super().__init__(*args, **kwargs)
35
+
36
+ if audioldm_cfg is not None:
37
+ self.audioldm = get_model()(audioldm_cfg)
38
+
39
+ if autokl_cfg is not None:
40
+ self.autokl = get_model()(autokl_cfg)
41
+
42
+ if optimus_cfg is not None:
43
+ self.optimus = get_model()(optimus_cfg)
44
+
45
+ if clip_cfg is not None:
46
+ self.clip = get_model()(clip_cfg)
47
+
48
+ if clap_cfg is not None:
49
+ self.clap = get_model()(clap_cfg)
50
+
51
+ if not scale_by_std:
52
+ self.vision_scale_factor = vision_scale_factor
53
+ self.text_scale_factor = text_scale_factor
54
+ self.audio_scale_factor = audio_scale_factor
55
+ else:
56
+ self.register_buffer("text_scale_factor", torch.tensor(text_scale_factor))
57
+ self.register_buffer("audio_scale_factor", torch.tensor(audio_scale_factor))
58
+ self.register_buffer('vision_scale_factor', torch.tensor(vision_scale_factor))
59
+
60
+ @property
61
+ def device(self):
62
+ return next(self.parameters()).device
63
+
64
+ @torch.no_grad()
65
+ def autokl_encode(self, image):
66
+ encoder_posterior = self.autokl.encode(image)
67
+ z = encoder_posterior.sample().to(image.dtype)
68
+ return self.vision_scale_factor * z
69
+
70
+ @torch.no_grad()
71
+ def autokl_decode(self, z):
72
+ z = 1. / self.vision_scale_factor * z
73
+ return self.autokl.decode(z)
74
+
75
+ @torch.no_grad()
76
+ def optimus_encode(self, text):
77
+ if isinstance(text, List):
78
+ tokenizer = self.optimus.tokenizer_encoder
79
+ token = [tokenizer.tokenize(sentence.lower()) for sentence in text]
80
+ token_id = []
81
+ for tokeni in token:
82
+ token_sentence = [tokenizer._convert_token_to_id(i) for i in tokeni]
83
+ token_sentence = tokenizer.add_special_tokens_single_sentence(token_sentence)
84
+ token_id.append(torch.LongTensor(token_sentence))
85
+ token_id = torch.nn.utils.rnn.pad_sequence(token_id, batch_first=True, padding_value=0.0)[:, :512]
86
+ else:
87
+ token_id = text
88
+ z = self.optimus.encoder(token_id, attention_mask=(token_id > 0))[1]
89
+ z_mu, z_logvar = self.optimus.encoder.linear(z).chunk(2, -1)
90
+ return z_mu.squeeze(1) * self.text_scale_factor
91
+
92
+ @torch.no_grad()
93
+ def optimus_decode(self, z, temperature=1.0, max_length=30):
94
+ z = 1.0 / self.text_scale_factor * z
95
+ return self.optimus.decode(z, temperature, max_length=max_length)
96
+
97
+ @torch.no_grad()
98
+ def audioldm_encode(self, audio, time=2.0):
99
+ encoder_posterior = self.audioldm.encode(audio, time=time)
100
+ z = encoder_posterior.sample().to(audio.dtype)
101
+ return z * self.audio_scale_factor
102
+
103
+ @torch.no_grad()
104
+ def audioldm_decode(self, z):
105
+ if torch.max(torch.abs(z)) > 1e2:
106
+ z = torch.clip(z, min=-10, max=10)
107
+ z = 1.0 / self.audio_scale_factor * z
108
+ return self.audioldm.decode(z)
109
+
110
+ @torch.no_grad()
111
+ def mel_spectrogram_to_waveform(self, mel):
112
+ # Mel: [bs, 1, t-steps, fbins]
113
+ if len(mel.size()) == 4:
114
+ mel = mel.squeeze(1)
115
+ mel = mel.permute(0, 2, 1)
116
+ waveform = self.audioldm.vocoder(mel)
117
+ waveform = waveform.cpu().detach().numpy()
118
+ return waveform
119
+
120
+ @torch.no_grad()
121
+ def clip_encode_text(self, text, encode_type='encode_text'):
122
+ swap_type = self.clip.encode_type
123
+ self.clip.encode_type = encode_type
124
+ embedding = self.clip(text, encode_type)
125
+ self.clip.encode_type = swap_type
126
+ return embedding
127
+
128
+ @torch.no_grad()
129
+ def clip_encode_vision(self, vision, encode_type='encode_vision'):
130
+ swap_type = self.clip.encode_type
131
+ self.clip.encode_type = encode_type
132
+ embedding = self.clip(vision, encode_type)
133
+ self.clip.encode_type = swap_type
134
+ return embedding
135
+
136
+ @torch.no_grad()
137
+ def clap_encode_audio(self, audio):
138
+ embedding = self.clap(audio)
139
+ return embedding
140
+
141
+ def forward(self, x=None, c=None, noise=None, xtype='image', ctype='prompt', u=None, return_algined_latents=False):
142
+ if isinstance(x, list):
143
+ t = torch.randint(0, self.num_timesteps, (x[0].shape[0],), device=x[0].device).long()
144
+ else:
145
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=x.device).long()
146
+ return self.p_losses(x, c, t, noise, xtype, ctype, u, return_algined_latents)
147
+
148
+ def apply_model(self, x_noisy, t, cond, xtype='image', ctype='text', u=None, return_algined_latents=False):
149
+ return self.model.diffusion_model(x_noisy, t, cond, xtype, ctype, u, return_algined_latents)
150
+
151
+ def get_pixel_loss(self, pred, target, mean=True):
152
+ if self.loss_type == 'l1':
153
+ loss = (target - pred).abs()
154
+ if mean:
155
+ loss = loss.mean()
156
+ elif self.loss_type == 'l2':
157
+ if mean:
158
+ loss = torch.nn.functional.mse_loss(target, pred)
159
+ else:
160
+ loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
161
+ else:
162
+ raise NotImplementedError(f"unknown loss type '{self.loss_type}'")
163
+ loss = torch.nan_to_num(loss, nan=0.0, posinf=0.0, neginf=-0.0)
164
+ return loss
165
+
166
+ def get_text_loss(self, pred, target):
167
+ if self.loss_type == 'l1':
168
+ loss = (target - pred).abs()
169
+ elif self.loss_type == 'l2':
170
+ loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
171
+ loss = torch.nan_to_num(loss, nan=0.0, posinf=0.0, neginf=0.0)
172
+ return loss
173
+
174
+ def p_losses(self, x_start, cond, t, noise=None, xtype='image', ctype='prompt', u=None, return_algined_latents=False):
175
+ if isinstance(x_start, list):
176
+ noise = [torch.randn_like(x_start_i) for x_start_i in x_start] if noise is None else noise
177
+ x_noisy = [self.q_sample(x_start=x_start_i, t=t, noise=noise_i) for x_start_i, noise_i in zip(x_start, noise)]
178
+ model_output = self.apply_model(x_noisy, t, cond, xtype, ctype, u, return_algined_latents)
179
+ if return_algined_latents:
180
+ return model_output
181
+
182
+ loss_dict = {}
183
+
184
+ if self.parameterization == "x0":
185
+ target = x_start
186
+ elif self.parameterization == "eps":
187
+ target = noise
188
+ else:
189
+ raise NotImplementedError()
190
+
191
+ loss = 0.0
192
+ for model_output_i, target_i, xtype_i in zip(model_output, target, xtype):
193
+ if xtype_i == 'image':
194
+ loss_simple = self.get_pixel_loss(model_output_i, target_i, mean=False).mean([1, 2, 3])
195
+ elif xtype_i == 'video':
196
+ loss_simple = self.get_pixel_loss(model_output_i, target_i, mean=False).mean([1, 2, 3, 4])
197
+ elif xtype_i == 'text':
198
+ loss_simple = self.get_text_loss(model_output_i, target_i).mean([1])
199
+ elif xtype_i == 'audio':
200
+ loss_simple = self.get_pixel_loss(model_output_i, target_i, mean=False).mean([1, 2, 3])
201
+ loss += loss_simple.mean()
202
+ return loss / len(xtype)
203
+
204
+ else:
205
+ noise = torch.randn_like(x_start) if noise is None else noise
206
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
207
+ model_output = self.apply_model(x_noisy, t, cond, xtype, ctype)
208
+
209
+ loss_dict = {}
210
+
211
+ if self.parameterization == "x0":
212
+ target = x_start
213
+ elif self.parameterization == "eps":
214
+ target = noise
215
+ else:
216
+ raise NotImplementedError()
217
+
218
+ if xtype == 'image':
219
+ loss_simple = self.get_pixel_loss(model_output, target, mean=False).mean([1, 2, 3])
220
+ elif xtype == 'video':
221
+ loss_simple = self.get_pixel_loss(model_output, target, mean=False).mean([1, 2, 3, 4])
222
+ elif xtype == 'text':
223
+ loss_simple = self.get_text_loss(model_output, target).mean([1])
224
+ elif xtype == 'audio':
225
+ loss_simple = self.get_pixel_loss(model_output, target, mean=False).mean([1, 2, 3])
226
+ loss = loss_simple.mean()
227
+ return loss
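
The encode/decode wrappers above all follow the same convention: the encoder output is multiplied by a per-modality scale factor and the decoder divides it back out before calling the first-stage decoder. A minimal, self-contained sketch of that round trip (tensor shape and values are illustrative only):

import torch

# Illustrative round trip of the latent scaling used by autokl_encode / autokl_decode.
vision_scale_factor = 0.1812            # default from CoDi.__init__ above
z_raw = torch.randn(1, 4, 32, 32)       # stand-in for an autokl posterior sample

z = vision_scale_factor * z_raw         # what autokl_encode returns
z_back = 1. / vision_scale_factor * z   # what autokl_decode feeds to autokl.decode
assert torch.allclose(z_back, z_raw)
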
core/models/codi_2.py ADDED
@@ -0,0 +1,221 @@
1
+ from typing import Dict, List
2
+ import os
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import numpy as np
8
+ import numpy.random as npr
9
+ import copy
10
+ from functools import partial
11
+ from contextlib import contextmanager
12
+
13
+ from .common.get_model import get_model, register
14
+ from .sd import DDPM
15
+
16
+ version = '0'
17
+ symbol = 'thesis_model'
18
+
19
+
20
+ @register('thesis_model', version)
21
+ class CoDi(DDPM):
22
+ def __init__(self,
23
+ autokl_cfg=None,
24
+ optimus_cfg=None,
25
+ clip_cfg=None,
26
+ vision_scale_factor=0.1812,
27
+ text_scale_factor=4.3108,
28
+ audio_scale_factor=0.9228,
29
+ scale_by_std=False,
30
+ *args,
31
+ **kwargs):
32
+ super().__init__(*args, **kwargs)
33
+
34
+ if autokl_cfg is not None:
35
+ self.autokl = get_model()(autokl_cfg)
36
+
37
+ if optimus_cfg is not None:
38
+ self.optimus = get_model()(optimus_cfg)
39
+
40
+ if clip_cfg is not None:
41
+ self.clip = get_model()(clip_cfg)
42
+
43
+ if not scale_by_std:
44
+ self.vision_scale_factor = vision_scale_factor
45
+ self.text_scale_factor = text_scale_factor
46
+ self.audio_scale_factor = audio_scale_factor
47
+ else:
48
+ self.register_buffer("text_scale_factor", torch.tensor(text_scale_factor))
49
+ self.register_buffer("audio_scale_factor", torch.tensor(audio_scale_factor))
50
+ self.register_buffer('vision_scale_factor', torch.tensor(vision_scale_factor))
51
+
52
+ @property
53
+ def device(self):
54
+ return next(self.parameters()).device
55
+
56
+ @torch.no_grad()
57
+ def autokl_encode(self, image):
58
+ encoder_posterior = self.autokl.encode(image)
59
+ z = encoder_posterior.sample().to(image.dtype)
60
+ return self.vision_scale_factor * z
61
+
62
+ @torch.no_grad()
63
+ def autokl_decode(self, z):
64
+ z = 1. / self.vision_scale_factor * z
65
+ return self.autokl.decode(z)
66
+
67
+ @torch.no_grad()
68
+ def optimus_encode(self, text):
69
+ if isinstance(text, List):
70
+ tokenizer = self.optimus.tokenizer_encoder
71
+ token = [tokenizer.tokenize(sentence.lower()) for sentence in text]
72
+ token_id = []
73
+ for tokeni in token:
74
+ token_sentence = [tokenizer._convert_token_to_id(i) for i in tokeni]
75
+ token_sentence = tokenizer.add_special_tokens_single_sentence(token_sentence)
76
+ token_id.append(torch.LongTensor(token_sentence))
77
+ token_id = torch.nn.utils.rnn.pad_sequence(token_id, batch_first=True, padding_value=0.0)[:, :512]
78
+ else:
79
+ token_id = text
80
+ z = self.optimus.encoder(token_id, attention_mask=(token_id > 0))[1]
81
+ z_mu, z_logvar = self.optimus.encoder.linear(z).chunk(2, -1)
82
+ return z_mu.squeeze(1) * self.text_scale_factor
83
+
84
+ @torch.no_grad()
85
+ def optimus_decode(self, z, temperature=1.0):
86
+ z = 1.0 / self.text_scale_factor * z
87
+ return self.optimus.decode(z, temperature)
88
+
89
+ @torch.no_grad()
90
+ def clip_encode_text(self, text, encode_type='encode_text'):
91
+ swap_type = self.clip.encode_type
92
+ self.clip.encode_type = encode_type
93
+ embedding = self.clip(text, encode_type)
94
+ self.clip.encode_type = swap_type
95
+ return embedding
96
+
97
+ @torch.no_grad()
98
+ def clip_encode_vision(self, vision, encode_type='encode_vision'):
99
+ swap_type = self.clip.encode_type
100
+ self.clip.encode_type = encode_type
101
+ embedding = self.clip(vision, encode_type)
102
+ self.clip.encode_type = swap_type
103
+ return embedding
104
+
105
+ @torch.no_grad()
106
+ def clap_encode_audio(self, audio):
107
+ embedding = self.clap(audio)
108
+ return embedding
109
+
110
+ def forward(self, x=None, c=None, noise=None, xtype='frontal', ctype='text', u=None, return_algined_latents=False, env_enc=False):
111
+ if isinstance(x, list):
112
+ t = torch.randint(0, self.num_timesteps, (x[0].shape[0],), device=x[0].device).long()
113
+ else:
114
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=x.device).long()
115
+ return self.p_losses(x, c, t, noise, xtype, ctype, u, return_algined_latents, env_enc)
116
+
117
+ def apply_model(self, x_noisy, t, cond, xtype='frontal', ctype='text', u=None, return_algined_latents=False, env_enc=False):
118
+ return self.model.diffusion_model(x_noisy, t, cond, xtype, ctype, u, return_algined_latents, env_enc=env_enc)
119
+
120
+ def get_pixel_loss(self, pred, target, mean=True):
121
+ if self.loss_type == 'l1':
122
+ loss = (target - pred).abs()
123
+ if mean:
124
+ loss = loss.mean()
125
+ elif self.loss_type == 'l2':
126
+ if mean:
127
+ loss = torch.nn.functional.mse_loss(target, pred)
128
+ else:
129
+ loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
130
+ else:
131
+ raise NotImplementedError(f"unknown loss type '{self.loss_type}'")
132
+ loss = torch.nan_to_num(loss, nan=0.0, posinf=0.0, neginf=-0.0)
133
+ return loss
134
+
135
+ def get_text_loss(self, pred, target):
136
+ if self.loss_type == 'l1':
137
+ loss = (target - pred).abs()
138
+ elif self.loss_type == 'l2':
139
+ loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
140
+ loss = torch.nan_to_num(loss, nan=0.0, posinf=0.0, neginf=0.0)
141
+ return loss
142
+
143
+ def p_losses(self, x_start, cond, t, noise=None, xtype='frontal', ctype='text', u=None,
144
+ return_algined_latents=False, env_enc=False):
145
+ if isinstance(x_start, list):
146
+ noise = [torch.randn_like(x_start_i) for x_start_i in x_start] if noise is None else noise
147
+ x_noisy = [self.q_sample(x_start=x_start_i, t=t, noise=noise_i) for x_start_i, noise_i in
148
+ zip(x_start, noise)]
149
+ if not env_enc:
150
+ model_output = self.apply_model(x_noisy, t, cond, xtype, ctype, u, return_algined_latents, env_enc)
151
+ else:
152
+ model_output, h_con = self.apply_model(x_noisy, t, cond, xtype, ctype, u, return_algined_latents, env_enc)
153
+ if return_algined_latents:
154
+ return model_output
155
+
156
+ loss_dict = {}
157
+
158
+ if self.parameterization == "x0":
159
+ target = x_start
160
+ elif self.parameterization == "eps":
161
+ target = noise
162
+ else:
163
+ raise NotImplementedError()
164
+
165
+ loss = 0.0
166
+ for model_output_i, target_i, xtype_i in zip(model_output, target, xtype):
167
+ if xtype_i == 'frontal':
168
+ loss_simple = self.get_pixel_loss(model_output_i, target_i, mean=False).mean([1, 2, 3])
169
+ elif xtype_i == 'text':
170
+ loss_simple = self.get_text_loss(model_output_i, target_i).mean([1])
171
+ elif xtype_i == 'lateral':
172
+ loss_simple = self.get_pixel_loss(model_output_i, target_i, mean=False).mean([1, 2, 3])
173
+ loss += loss_simple.mean()
174
+
175
+ # Check whether the model also returned h_con.
+ # If it did, we have the latent representations of the two modalities
+ # extracted by the environmental encoders, two tensors of shape batch_size x 1 x 1280,
+ # so we can also use them to compute a contrastive loss term (cross-entropy, as in CLIP).
+ if env_enc and h_con is not None:
180
+ def similarity(z_a, z_b):
181
+ return F.cosine_similarity(z_a, z_b)
182
+
183
+ z_a, z_b = h_con
184
+
185
+ z_a = z_a / z_a.norm(dim=-1, keepdim=True)
186
+ z_b = z_b / z_b.norm(dim=-1, keepdim=True)
187
+
188
+ logits_a = z_a.squeeze() @ z_b.squeeze().t()
189
+ logits_b = z_b.squeeze() @ z_a.squeeze().t()
190
+
191
+ labels = torch.arange(len(z_a)).to(z_a.device)
192
+
193
+ loss_a = F.cross_entropy(logits_a, labels)
194
+ loss_b = F.cross_entropy(logits_b, labels)
195
+
196
+ loss_con = (loss_a + loss_b) / 2
197
+ loss += loss_con
198
+ return loss / len(xtype)
199
+
200
+ else:
201
+ noise = torch.randn_like(x_start) if noise is None else noise
202
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
203
+ model_output = self.apply_model(x_noisy, t, cond, xtype, ctype)
204
+
205
+ loss_dict = {}
206
+
207
+ if self.parameterization == "x0":
208
+ target = x_start
209
+ elif self.parameterization == "eps":
210
+ target = noise
211
+ else:
212
+ raise NotImplementedError()
213
+
214
+ if xtype == 'frontal':
215
+ loss_simple = self.get_pixel_loss(model_output, target, mean=False).mean([1, 2, 3])
216
+ elif xtype == 'text':
217
+ loss_simple = self.get_text_loss(model_output, target).mean([1])
218
+ elif xtype == 'lateral':
219
+ loss_simple = self.get_pixel_loss(model_output, target, mean=False).mean([1, 2, 3])
220
+ loss = loss_simple.mean()
221
+ return loss
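
The contrastive branch in p_losses can be exercised in isolation. A standalone sketch of the symmetric CLIP-style term computed on the two environmental-encoder latents (batch size and embedding dimension are illustrative; the real h_con tensors come from the diffusion model when env_enc=True):

import torch
import torch.nn.functional as F

z_a = torch.randn(4, 1, 1280)                 # e.g. frontal-branch latent (illustrative)
z_b = torch.randn(4, 1, 1280)                 # e.g. lateral-branch latent (illustrative)

z_a = z_a / z_a.norm(dim=-1, keepdim=True)
z_b = z_b / z_b.norm(dim=-1, keepdim=True)

logits_a = z_a.squeeze() @ z_b.squeeze().t()  # [4, 4] similarity matrix
logits_b = logits_a.t()

labels = torch.arange(len(z_a))               # matching pairs sit on the diagonal
loss_con = (F.cross_entropy(logits_a, labels) + F.cross_entropy(logits_b, labels)) / 2
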
core/models/common/__pycache__/get_model.cpython-38.pyc ADDED
Binary file (2.96 kB). View file
 
core/models/common/__pycache__/get_optimizer.cpython-38.pyc ADDED
Binary file (1.94 kB). View file
 
core/models/common/__pycache__/get_scheduler.cpython-38.pyc ADDED
Binary file (9.55 kB). View file
 
core/models/common/__pycache__/utils.cpython-38.pyc ADDED
Binary file (9.75 kB). View file
 
core/models/common/get_model.py ADDED
@@ -0,0 +1,88 @@
1
2
+ import torch
3
+ import torchvision.models
4
+ import os.path as osp
5
+ import copy
6
+ from .utils import \
7
+ get_total_param, get_total_param_sum, \
8
+ get_unit
9
+
10
+
11
+ def singleton(class_):
12
+ instances = {}
13
+
14
+ def getinstance(*args, **kwargs):
15
+ if class_ not in instances:
16
+ instances[class_] = class_(*args, **kwargs)
17
+ return instances[class_]
18
+ return getinstance
19
+
20
+
21
+ def preprocess_model_args(args):
22
+ # If args has layer_units, resolve each entry into its unit callable.
+ # If args has a backbone config, build the backbone model.
25
+ args = copy.deepcopy(args)
26
+ if 'layer_units' in args:
27
+ layer_units = [
28
+ get_unit()(i) for i in args.layer_units
29
+ ]
30
+ args.layer_units = layer_units
31
+ if 'backbone' in args:
32
+ args.backbone = get_model()(args.backbone)
33
+ return args
34
+
35
+ @singleton
36
+ class get_model(object):
37
+ def __init__(self):
38
+ self.model = {}
39
+ self.version = {}
40
+
41
+ def register(self, model, name, version='x'):
42
+ self.model[name] = model
43
+ self.version[name] = version
44
+
45
+ def __call__(self, cfg, verbose=True):
46
+ """
47
+ Construct model based on the config.
48
+ """
49
+ t = cfg.type
50
+
51
+ # the register is in each file
52
+ if t.find('audioldm')==0:
53
+ from ..latent_diffusion.vae import audioldm
54
+ elif t.find('autoencoderkl')==0:
55
+ from ..latent_diffusion.vae import autokl
56
+ elif t.find('optimus')==0:
57
+ from ..latent_diffusion.vae import optimus
58
+
59
+ elif t.find('clip')==0:
60
+ from ..encoders import clip
61
+ elif t.find('clap')==0:
62
+ from ..encoders import clap
63
+
64
+ elif t.find('sd')==0:
65
+ from .. import sd
66
+ elif t.find('codi')==0:
67
+ from .. import codi
68
+ elif t.find('thesis_model')==0:
69
+ from .. import codi_2
70
+ elif t.find('openai_unet')==0:
71
+ from ..latent_diffusion import diffusion_unet
72
+ elif t.find('prova')==0:
73
+ from ..latent_diffusion import diffusion_unet
74
+
75
+ args = preprocess_model_args(cfg.args)
76
+ net = self.model[t](**args)
77
+
78
+ return net
79
+
80
+ def get_version(self, name):
81
+ return self.version[name]
82
+
83
+
84
+ def register(name, version='x'):
85
+ def wrapper(class_):
86
+ get_model().register(class_, name, version)
87
+ return class_
88
+ return wrapper
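
A hypothetical usage sketch of the registry above: the @register decorator stores a class in the get_model singleton, and a small attribute-style config drives construction. ToyNet and the Cfg helper are illustrative stand-ins; only register and get_model come from this file.

import torch.nn as nn
from core.models.common.get_model import get_model, register

class Cfg(dict):
    # attribute-style access, mimicking the config objects used in this repo
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

@register('toy_net', version='0')
class ToyNet(nn.Module):
    def __init__(self, hidden_dim=16):
        super().__init__()
        self.fc = nn.Linear(hidden_dim, hidden_dim)

cfg = Cfg(type='toy_net', args=Cfg(hidden_dim=32))
net = get_model()(cfg)                       # looks up 'toy_net' in the singleton and builds it
print(get_model().get_version('toy_net'))    # -> '0'
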
core/models/common/get_optimizer.py ADDED
@@ -0,0 +1,50 @@
1
+ import torch
2
+ import torch.optim as optim
3
+ import numpy as np
4
+ import itertools
5
+
6
+
7
+ def singleton(class_):
8
+ instances = {}
9
+
10
+ def getinstance(*args, **kwargs):
11
+ if class_ not in instances:
12
+ instances[class_] = class_(*args, **kwargs)
13
+ return instances[class_]
14
+ return getinstance
15
+
16
+
17
+ class get_optimizer(object):
18
+ def __init__(self):
19
+ self.optimizer = {}
20
+ self.register(optim.SGD, 'sgd')
21
+ self.register(optim.Adam, 'adam')
22
+ self.register(optim.AdamW, 'adamw')
23
+
24
+ def register(self, optim, name):
25
+ self.optimizer[name] = optim
26
+
27
+ def __call__(self, net, cfg):
28
+ if cfg is None:
29
+ return None
30
+ t = cfg.type
31
+ if isinstance(net, (torch.nn.DataParallel,
32
+ torch.nn.parallel.DistributedDataParallel)):
33
+ netm = net.module
34
+ else:
35
+ netm = net
36
+ pg = getattr(netm, 'parameter_group', None)
37
+
38
+ if pg is not None:
39
+ params = []
40
+ for group_name, module_or_para in pg.items():
41
+ if not isinstance(module_or_para, list):
42
+ module_or_para = [module_or_para]
43
+
44
+ grouped_params = [mi.parameters() if isinstance(mi, torch.nn.Module) else [mi] for mi in module_or_para]
45
+ grouped_params = itertools.chain(*grouped_params)
46
+ pg_dict = {'params': grouped_params, 'name': group_name}
47
+ params.append(pg_dict)
48
+ else:
49
+ params = net.parameters()
50
+ return self.optimizer[t](params, lr=0, **cfg.args)
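
A hypothetical sketch of how an optional parameter_group attribute on a network becomes named optimizer param groups. TinyModel, the Cfg helper and the hyperparameters are illustrative; lr is set to 0 by get_optimizer because a scheduler is expected to overwrite it each step.

import torch.nn as nn
from core.models.common.get_optimizer import get_optimizer

class Cfg(dict):
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)
        self.head = nn.Linear(8, 2)
        # optional attribute picked up by get_optimizer.__call__
        self.parameter_group = {'backbone': self.backbone, 'head': self.head}

net = TinyModel()
cfg = Cfg(type='adamw', args=Cfg(weight_decay=0.01))
opt = get_optimizer()(net, cfg)
print([g['name'] for g in opt.param_groups])   # -> ['backbone', 'head']
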
core/models/common/get_scheduler.py ADDED
@@ -0,0 +1,273 @@
1
+ import torch
2
+ import torch.optim as optim
3
+ import numpy as np
4
+ import copy
5
+ from ... import sync
6
+ from ...cfg_holder import cfg_unique_holder as cfguh
7
+
8
+
9
+ def singleton(class_):
10
+ instances = {}
11
+
12
+ def getinstance(*args, **kwargs):
13
+ if class_ not in instances:
14
+ instances[class_] = class_(*args, **kwargs)
15
+ return instances[class_]
16
+ return getinstance
17
+
18
+
19
+ @singleton
20
+ class get_scheduler(object):
21
+ def __init__(self):
22
+ self.lr_scheduler = {}
23
+
24
+ def register(self, lrsf, name):
25
+ self.lr_scheduler[name] = lrsf
26
+
27
+ def __call__(self, cfg):
28
+ if cfg is None:
29
+ return None
30
+ if isinstance(cfg, list):
31
+ schedulers = []
32
+ for ci in cfg:
33
+ t = ci.type
34
+ schedulers.append(
35
+ self.lr_scheduler[t](**ci.args))
36
+ if len(schedulers) == 0:
37
+ raise ValueError
38
+ else:
39
+ return compose_scheduler(schedulers)
40
+ t = cfg.type
41
+ return self.lr_scheduler[t](**cfg.args)
42
+
43
+
44
+ def register(name):
45
+ def wrapper(class_):
46
+ get_scheduler().register(class_, name)
47
+ return class_
48
+ return wrapper
49
+
50
+
51
+ class template_scheduler(object):
52
+ def __init__(self, step):
53
+ self.step = step
54
+
55
+ def __getitem__(self, idx):
56
+ raise ValueError
57
+
58
+ def set_lr(self, optim, new_lr, pg_lrscale=None):
59
+ """
60
+ Set Each parameter_groups in optim with new_lr
61
+ New_lr can be find according to the idx.
62
+ pg_lrscale tells how to scale each pg.
63
+ """
64
+ # new_lr = self.__getitem__(idx)
65
+ pg_lrscale = copy.deepcopy(pg_lrscale)
66
+ for pg in optim.param_groups:
67
+ if pg_lrscale is None:
68
+ pg['lr'] = new_lr
69
+ else:
70
+ pg['lr'] = new_lr * pg_lrscale.pop(pg['name'])
71
+ assert (pg_lrscale is None) or (len(pg_lrscale)==0), \
72
+ "pg_lrscale doesn't match pg"
73
+
74
+ @register('constant')
75
+ class constant_scheduler(template_scheduler):
76
+ def __init__(self, lr, step):
77
+ super().__init__(step)
78
+ self.lr = lr
79
+
80
+ def __getitem__(self, idx):
81
+ if idx >= self.step:
82
+ raise ValueError
83
+ return self.lr
84
+
85
+
86
+ @register('poly')
87
+ class poly_scheduler(template_scheduler):
88
+ def __init__(self, start_lr, end_lr, power, step):
89
+ super().__init__(step)
90
+ self.start_lr = start_lr
91
+ self.end_lr = end_lr
92
+ self.power = power
93
+
94
+ def __getitem__(self, idx):
95
+ if idx >= self.step:
96
+ raise ValueError
97
+ a, b = self.start_lr, self.end_lr
98
+ p, n = self.power, self.step
99
+ return b + (a-b)*((1-idx/n)**p)
100
+
101
+
102
+ @register('linear')
103
+ class linear_scheduler(template_scheduler):
104
+ def __init__(self, start_lr, end_lr, step):
105
+ super().__init__(step)
106
+ self.start_lr = start_lr
107
+ self.end_lr = end_lr
108
+
109
+ def __getitem__(self, idx):
110
+ if idx >= self.step:
111
+ raise ValueError
112
+ a, b, n = self.start_lr, self.end_lr, self.step
113
+ return b + (a-b)*(1-idx/n)
114
+
115
+
116
+ @register('multistage')
117
+ class constant_scheduler(template_scheduler):
118
+ def __init__(self, start_lr, milestones, gamma, step):
119
+ super().__init__(step)
120
+ self.start_lr = start_lr
121
+ m = [0] + milestones + [step]
122
+ lr_iter = start_lr
123
+ self.lr = []
124
+ for ms, me in zip(m[0:-1], m[1:]):
125
+ for _ in range(ms, me):
126
+ self.lr.append(lr_iter)
127
+ lr_iter *= gamma
128
+
129
+ def __getitem__(self, idx):
130
+ if idx >= self.step:
131
+ raise ValueError
132
+ return self.lr[idx]
133
+
134
+
135
+ class compose_scheduler(template_scheduler):
136
+ def __init__(self, schedulers):
137
+ self.schedulers = schedulers
138
+ self.step = [si.step for si in schedulers]
139
+ self.step_milestone = []
140
+ acc = 0
141
+ for i in self.step:
142
+ acc += i
143
+ self.step_milestone.append(acc)
144
+ self.step = sum(self.step)
145
+
146
+ def __getitem__(self, idx):
147
+ if idx >= self.step:
148
+ raise ValueError
149
+ ms = self.step_milestone
150
+ for idx, (mi, mj) in enumerate(zip(ms[:-1], ms[1:])):
151
+ if mi <= idx < mj:
152
+ return self.schedulers[idx-mi]
153
+ raise ValueError
154
+
155
+ ####################
156
+ # lambda scheduler #
157
+ ####################
158
+
159
+
160
+ class LambdaWarmUpCosineScheduler(template_scheduler):
161
+ """
162
+ note: use with a base_lr of 1.0
163
+ """
164
+ def __init__(self,
165
+ base_lr,
166
+ warm_up_steps,
167
+ lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
168
+ cfgt = cfguh().cfg.train
169
+ bs = cfgt.batch_size
170
+ if 'gradacc_every' not in cfgt:
171
+ print('Warning: gradacc_every not found in the train config, defaulting to 1.')
172
+ acc = cfgt.get('gradacc_every', 1)
173
+ self.lr_multi = base_lr * bs * acc
174
+ self.lr_warm_up_steps = warm_up_steps
175
+ self.lr_start = lr_start
176
+ self.lr_min = lr_min
177
+ self.lr_max = lr_max
178
+ self.lr_max_decay_steps = max_decay_steps
179
+ self.last_lr = 0.
180
+ self.verbosity_interval = verbosity_interval
181
+
182
+ def schedule(self, n):
183
+ if self.verbosity_interval > 0:
184
+ if n % self.verbosity_interval == 0:
185
+ print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
186
+ if n < self.lr_warm_up_steps:
187
+ lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
188
+ self.last_lr = lr
189
+ return lr
190
+ else:
191
+ t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
192
+ t = min(t, 1.0)
193
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
194
+ 1 + np.cos(t * np.pi))
195
+ self.last_lr = lr
196
+ return lr
197
+
198
+ def __getitem__(self, idx):
199
+ return self.schedule(idx) * self.lr_multi
200
+
201
+
202
+ class LambdaWarmUpCosineScheduler2(template_scheduler):
203
+ """
204
+ supports repeated iterations, configurable via lists
205
+ note: use with a base_lr of 1.0.
206
+ """
207
+ def __init__(self,
208
+ base_lr,
209
+ warm_up_steps,
210
+ f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
211
+ cfgt = cfguh().cfg.train
212
+ # bs = cfgt.batch_size
213
+ # if 'gradacc_every' not in cfgt:
214
+ # print('Warning, gradacc_every is not found in xml, use 1 as default.')
215
+ # acc = cfgt.get('gradacc_every', 1)
216
+ # self.lr_multi = base_lr * bs * acc
217
+ self.lr_multi = base_lr
218
+ assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
219
+ self.lr_warm_up_steps = warm_up_steps
220
+ self.f_start = f_start
221
+ self.f_min = f_min
222
+ self.f_max = f_max
223
+ self.cycle_lengths = cycle_lengths
224
+ self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
225
+ self.last_f = 0.
226
+ self.verbosity_interval = verbosity_interval
227
+
228
+ def find_in_interval(self, n):
229
+ interval = 0
230
+ for cl in self.cum_cycles[1:]:
231
+ if n <= cl:
232
+ return interval
233
+ interval += 1
234
+
235
+ def schedule(self, n):
236
+ cycle = self.find_in_interval(n)
237
+ n = n - self.cum_cycles[cycle]
238
+ if self.verbosity_interval > 0:
239
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
240
+ f"current cycle {cycle}")
241
+ if n < self.lr_warm_up_steps[cycle]:
242
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
243
+ self.last_f = f
244
+ return f
245
+ else:
246
+ t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
247
+ t = min(t, 1.0)
248
+ f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
249
+ 1 + np.cos(t * np.pi))
250
+ self.last_f = f
251
+ return f
252
+
253
+ def __getitem__(self, idx):
254
+ return self.schedule(idx) * self.lr_multi
255
+
256
+
257
+ @register('stable_diffusion_linear')
258
+ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
259
+ def schedule(self, n):
260
+ cycle = self.find_in_interval(n)
261
+ n = n - self.cum_cycles[cycle]
262
+ if self.verbosity_interval > 0:
263
+ if n % self.verbosity_interval == 0:
264
+ print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
265
+ f"current cycle {cycle}")
266
+ if n < self.lr_warm_up_steps[cycle]:
267
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
268
+ self.last_f = f
269
+ return f
270
+ else:
271
+ f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
272
+ self.last_f = f
273
+ return f
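
A hypothetical sketch of driving the registry with a list-valued config, which yields a compose_scheduler: a constant warm phase followed by a linear decay. The Cfg helper and the numbers are illustrative, and the import assumes the repo's core.sync and core.cfg_holder modules are available; indexing the scheduler returns the learning rate for that global step, which training code then pushes into the optimizer via set_lr.

from core.models.common.get_scheduler import get_scheduler

class Cfg(dict):
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

cfg = [
    Cfg(type='constant', args=Cfg(lr=1e-4, step=100)),                    # steps 0..99
    Cfg(type='linear', args=Cfg(start_lr=1e-4, end_lr=1e-6, step=900)),   # steps 100..999
]
sched = get_scheduler()(cfg)          # a compose_scheduler spanning 1000 steps
print(sched[0], sched[100], sched[999])
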
core/models/common/utils.py ADDED
@@ -0,0 +1,310 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import functools
6
+ import itertools
7
+
8
+
9
+ ########
10
+ # unit #
11
+ ########
12
+
13
+
14
+ def singleton(class_):
15
+ instances = {}
16
+
17
+ def getinstance(*args, **kwargs):
18
+ if class_ not in instances:
19
+ instances[class_] = class_(*args, **kwargs)
20
+ return instances[class_]
21
+
22
+ return getinstance
23
+
24
+
25
+ def str2value(v):
26
+ v = v.strip()
27
+ try:
28
+ return int(v)
29
+ except:
30
+ pass
31
+ try:
32
+ return float(v)
33
+ except:
34
+ pass
35
+ if v in ('True', 'true'):
36
+ return True
37
+ elif v in ('False', 'false'):
38
+ return False
39
+ else:
40
+ return v
41
+
42
+
43
+ @singleton
44
+ class get_unit(object):
45
+ def __init__(self):
46
+ self.unit = {}
47
+ self.register('none', None)
48
+
49
+ # general convolution
50
+ self.register('conv', nn.Conv2d)
51
+ self.register('bn', nn.BatchNorm2d)
52
+ self.register('relu', nn.ReLU)
53
+ self.register('relu6', nn.ReLU6)
54
+ self.register('lrelu', nn.LeakyReLU)
55
+ self.register('dropout', nn.Dropout)
56
+ self.register('dropout2d', nn.Dropout2d)
57
+ self.register('sine', Sine)
58
+ self.register('relusine', ReLUSine)
59
+
60
+ def register(self,
61
+ name,
62
+ unitf, ):
63
+
64
+ self.unit[name] = unitf
65
+
66
+ def __call__(self, name):
67
+ if name is None:
68
+ return None
69
+ i = name.find('(')
70
+ i = len(name) if i == -1 else i
71
+ t = name[:i]
72
+ f = self.unit[t]
73
+ args = name[i:].strip('()')
74
+ if len(args) == 0:
75
+ args = {}
76
+ return f
77
+ else:
78
+ args = args.split('=')
79
+ args = [[','.join(i.split(',')[:-1]), i.split(',')[-1]] for i in args]
80
+ args = list(itertools.chain.from_iterable(args))
81
+ args = [i.strip() for i in args if len(i) > 0]
82
+ kwargs = {}
83
+ for k, v in zip(args[::2], args[1::2]):
84
+ if v[0] == '(' and v[-1] == ')':
85
+ kwargs[k] = tuple([str2value(i) for i in v.strip('()').split(',')])
86
+ elif v[0] == '[' and v[-1] == ']':
87
+ kwargs[k] = [str2value(i) for i in v.strip('[]').split(',')]
88
+ else:
89
+ kwargs[k] = str2value(v)
90
+ return functools.partial(f, **kwargs)
91
+
92
+
93
+ def register(name):
94
+ def wrapper(class_):
95
+ get_unit().register(name, class_)
96
+ return class_
97
+
98
+ return wrapper
99
+
100
+
101
+ class Sine(object):
102
+ def __init__(self, freq, gain=1):
103
+ self.freq = freq
104
+ self.gain = gain
105
+ self.repr = 'sine(freq={}, gain={})'.format(freq, gain)
106
+
107
+ def __call__(self, x, gain=1):
108
+ act_gain = self.gain * gain
109
+ return torch.sin(self.freq * x) * act_gain
110
+
111
+ def __repr__(self, ):
112
+ return self.repr
113
+
114
+
115
+ class ReLUSine(nn.Module):
116
+ def __init(self):
117
+ super().__init__()
118
+
119
+ def forward(self, input):
120
+ a = torch.sin(30 * input)
121
+ b = nn.ReLU(inplace=False)(input)
122
+ return a + b
123
+
124
+
125
+ @register('lrelu_agc')
126
+ class lrelu_agc(object):
127
+ """
128
+ The lrelu layer with alpha, gain and clamp
129
+ """
130
+
131
+ def __init__(self, alpha=0.1, gain=1, clamp=None):
132
+ # super().__init__()
133
+ self.alpha = alpha
134
+ if gain == 'sqrt_2':
135
+ self.gain = np.sqrt(2)
136
+ else:
137
+ self.gain = gain
138
+ self.clamp = clamp
139
+ self.repr = 'lrelu_agc(alpha={}, gain={}, clamp={})'.format(
140
+ alpha, gain, clamp)
141
+
142
+ # def forward(self, x, gain=1):
143
+ def __call__(self, x, gain=1):
144
+ x = F.leaky_relu(x, negative_slope=self.alpha, inplace=True)
145
+ act_gain = self.gain * gain
146
+ act_clamp = self.clamp * gain if self.clamp is not None else None
147
+ if act_gain != 1:
148
+ x = x * act_gain
149
+ if act_clamp is not None:
150
+ x = x.clamp(-act_clamp, act_clamp)
151
+ return x
152
+
153
+ def __repr__(self, ):
154
+ return self.repr
155
+
156
+
157
+ ####################
158
+ # spatial encoding #
159
+ ####################
160
+
161
+
162
+ @register('se')
163
+ class SpatialEncoding(nn.Module):
164
+ def __init__(self,
165
+ in_dim,
166
+ out_dim,
167
+ sigma=6,
168
+ cat_input=True,
169
+ require_grad=False, ):
170
+
171
+ super().__init__()
172
+ assert out_dim % (2 * in_dim) == 0, "dimension must be dividable"
173
+
174
+ n = out_dim // 2 // in_dim
175
+ m = 2 ** np.linspace(0, sigma, n)
176
+ m = np.stack([m] + [np.zeros_like(m)] * (in_dim - 1), axis=-1)
177
+ m = np.concatenate([np.roll(m, i, axis=-1) for i in range(in_dim)], axis=0)
178
+ self.emb = torch.FloatTensor(m)
179
+ if require_grad:
180
+ self.emb = nn.Parameter(self.emb, requires_grad=True)
181
+ self.in_dim = in_dim
182
+ self.out_dim = out_dim
183
+ self.sigma = sigma
184
+ self.cat_input = cat_input
185
+ self.require_grad = require_grad
186
+
187
+ def forward(self, x, format='[n x c]'):
188
+ """
189
+ Args:
190
+ x: [n x m1],
191
+ m1 usually is 2
192
+ Outputs:
193
+ y: [n x m2]
194
+ m2 dimention number
195
+ :param format:
196
+ """
197
+ if format == '[bs x c x 2D]':
198
+ xshape = x.shape
199
+ x = x.permute(0, 2, 3, 1).contiguous()
200
+ x = x.view(-1, x.size(-1))
201
+ elif format == '[n x c]':
202
+ pass
203
+ else:
204
+ raise ValueError
205
+
206
+ if not self.require_grad:
207
+ self.emb = self.emb.to(x.device)
208
+ y = torch.mm(x, self.emb.T)
209
+ if self.cat_input:
210
+ z = torch.cat([x, torch.sin(y), torch.cos(y)], dim=-1)
211
+ else:
212
+ z = torch.cat([torch.sin(y), torch.cos(y)], dim=-1)
213
+
214
+ if format == '[bs x c x 2D]':
215
+ z = z.view(xshape[0], xshape[2], xshape[3], -1)
216
+ z = z.permute(0, 3, 1, 2).contiguous()
217
+ return z
218
+
219
+ def extra_repr(self):
220
+ outstr = 'SpatialEncoding (in={}, out={}, sigma={}, cat_input={}, require_grad={})'.format(
221
+ self.in_dim, self.out_dim, self.sigma, self.cat_input, self.require_grad)
222
+ return outstr
223
+
224
+
225
+ @register('rffe')
226
+ class RFFEncoding(SpatialEncoding):
227
+ """
228
+ Random Fourier Features
229
+ """
230
+
231
+ def __init__(self,
232
+ in_dim,
233
+ out_dim,
234
+ sigma=6,
235
+ cat_input=True,
236
+ require_grad=False, ):
237
+ super().__init__(in_dim, out_dim, sigma, cat_input, require_grad)
238
+ n = out_dim // 2
239
+ m = np.random.normal(0, sigma, size=(n, in_dim))
240
+ self.emb = torch.FloatTensor(m)
241
+ if require_grad:
242
+ self.emb = nn.Parameter(self.emb, requires_grad=True)
243
+
244
+ def extra_repr(self):
245
+ outstr = 'RFFEncoding (in={}, out={}, sigma={}, cat_input={}, require_grad={})'.format(
246
+ self.in_dim, self.out_dim, self.sigma, self.cat_input, self.require_grad)
247
+ return outstr
248
+
249
+
250
+ ##########
251
+ # helper #
252
+ ##########
253
+
254
+
255
+ def freeze(net):
256
+ for m in net.modules():
257
+ if isinstance(m, (
258
+ nn.BatchNorm2d,
259
+ nn.SyncBatchNorm,)):
260
+ # inplace_abn not supported
261
+ m.eval()
262
+ for pi in net.parameters():
263
+ pi.requires_grad = False
264
+ return net
265
+
266
+
267
+ def common_init(m):
268
+ if isinstance(m, (
269
+ nn.Conv2d,
270
+ nn.ConvTranspose2d,)):
271
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
272
+ if m.bias is not None:
273
+ nn.init.constant_(m.bias, 0)
274
+ elif isinstance(m, (
275
+ nn.BatchNorm2d,
276
+ nn.SyncBatchNorm,)):
277
+ nn.init.constant_(m.weight, 1)
278
+ nn.init.constant_(m.bias, 0)
279
+ else:
280
+ pass
281
+
282
+
283
+ def init_module(module):
284
+ """
285
+ Args:
286
+ module: [nn.module] list or nn.module
287
+ a list of module to be initialized.
288
+ """
289
+ if isinstance(module, (list, tuple)):
290
+ module = list(module)
291
+ else:
292
+ module = [module]
293
+
294
+ for mi in module:
295
+ for mii in mi.modules():
296
+ common_init(mii)
297
+
298
+
299
+ def get_total_param(net):
300
+ if getattr(net, 'parameters', None) is None:
301
+ return 0
302
+ return sum(p.numel() for p in net.parameters())
303
+
304
+
305
+ def get_total_param_sum(net):
306
+ if getattr(net, 'parameters', None) is None:
307
+ return 0
308
+ with torch.no_grad():
309
+ s = sum(p.cpu().detach().numpy().sum().item() for p in net.parameters())
310
+ return s
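
The get_unit registry above resolves string specs such as 'conv(kernel_size=3, padding=1)' into either the registered class (bare name) or a functools.partial with the parsed keyword arguments bound. A small hypothetical sketch; the spec strings and channel counts are examples:

from core.models.common.utils import get_unit

relu_cls = get_unit()('relu')                    # bare name -> the registered class (nn.ReLU)
conv_fn = get_unit()('conv(kernel_size=3, padding=1)')
act_fn = get_unit()('lrelu_agc(alpha=0.2, gain=sqrt_2, clamp=256)')

layer = conv_fn(16, 32)     # nn.Conv2d(16, 32, kernel_size=3, padding=1)
act = act_fn()              # lrelu_agc instance with the parsed kwargs
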
core/models/dani_model.py ADDED
@@ -0,0 +1,170 @@
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import torchvision.transforms as tvtrans
7
+
8
+ from einops import rearrange
9
+
10
+ import pytorch_lightning as pl
11
+
12
+ from . import get_model
13
+ from ..cfg_helper import model_cfg_bank
14
+ from ..common.utils import regularize_image, regularize_video, remove_duplicate_word
15
+
16
+ import warnings
17
+
18
+ warnings.filterwarnings("ignore")
19
+
20
+
21
+ class dani_model(pl.LightningModule):
22
+ def __init__(self, model='thesis_model', load_weights=True, data_dir='pretrained', pth=["CoDi_encoders.pth"], fp16=False):
23
+ super().__init__()
24
+ # import torch
25
+ # device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
26
+ cfgm = model_cfg_bank()(model)
27
+ net = get_model()(cfgm)
28
+ if load_weights:
29
+ for path in pth:
30
+ net.load_state_dict(torch.load(os.path.join(data_dir, path), map_location='cpu'), strict=False)
31
+ print('Loaded pretrained weights from {}'.format(pth))
32
+
33
+ self.net = net
34
+
35
+ from core.models.ddim.ddim_vd import DDIMSampler_VD
36
+ self.sampler = DDIMSampler_VD(net)
37
+
38
+ def decode(self, z, xtype):
39
+ device = z.device
40
+ net = self.net
41
+ z = z.to(device)
42
+ if xtype == 'image':
43
+ x = net.autokl_decode(z)
44
+ x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0)
45
+ return x
46
+
47
+ elif xtype == 'video':
48
+ num_frames = z.shape[2]
49
+ z = rearrange(z, 'b c f h w -> (b f) c h w')
50
+ x = net.autokl_decode(z)
51
+ x = rearrange(x, '(b f) c h w -> b f c h w', f=num_frames)
52
+
53
+ x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0)
54
+ video_list = []
55
+ for video in x:
56
+ video_list.append([tvtrans.ToPILImage()(xi) for xi in video])
57
+ return video_list
58
+
59
+ elif xtype == 'text':
60
+ prompt_temperature = 1.0
61
+ prompt_merge_same_adj_word = True
62
+ x = net.optimus_decode(z, temperature=prompt_temperature)
63
+ """
64
+ if prompt_merge_same_adj_word:
65
+ xnew = []
66
+ for xi in x:
67
+ xi_split = xi.split()
68
+ xinew = []
69
+ for idxi, wi in enumerate(xi_split):
70
+ if idxi!=0 and wi==xi_split[idxi-1]:
71
+ continue
72
+ xinew.append(wi)
73
+ xnew.append(remove_duplicate_word(' '.join(xinew)))
74
+ x = xnew
75
+ """
76
+ return x
77
+
78
+ elif xtype == 'audio':
79
+ x = net.audioldm_decode(z)
80
+ x = net.mel_spectrogram_to_waveform(x)
81
+ return x
82
+
83
+ def forward(self, xtype=[], condition=[], condition_types=[], n_samples=1,
84
+ mix_weight={'video': 1, 'audio': 1, 'text': 1, 'image': 1}, image_size=256, ddim_steps=50, scale=7.5,
85
+ num_frames=8):
86
+ # import torch
87
+ # device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
88
+ device = self.device
89
+ net = self.net
90
+ sampler = self.sampler
91
+ ddim_eta = 0.0
92
+
93
+ conditioning = []
94
+ assert len(set(condition_types)) == len(condition_types), "we don't support condition with same modalities yet."
95
+ assert len(condition) == len(condition_types)
96
+
97
+ for i, condition_type in enumerate(condition_types):
98
+ if condition_type == 'image':
99
+ ctemp1 = regularize_image(condition[i]).squeeze().to(device)
102
+ ctemp1 = ctemp1[None].repeat(n_samples, 1, 1, 1)
103
+ cim = net.clip_encode_vision(ctemp1).to(device)
104
+ uim = None
105
+ if scale != 1.0:
106
+ dummy = torch.zeros_like(ctemp1).to(device)
107
+ uim = net.clip_encode_vision(dummy).to(device)
108
+ conditioning.append(torch.cat([uim, cim]))
109
+
110
+ elif condition_type == 'video':
111
+ ctemp1 = regularize_video(condition[i]).to(device)
112
+ ctemp1 = ctemp1[None].repeat(n_samples, 1, 1, 1, 1)
113
+ cim = net.clip_encode_vision(ctemp1).to(device)
114
+ uim = None
115
+ if scale != 1.0:
116
+ dummy = torch.zeros_like(ctemp1).to(device)
117
+ uim = net.clip_encode_vision(dummy).to(device)
118
+ conditioning.append(torch.cat([uim, cim]))
119
+
120
+ elif condition_type == 'audio':
121
+ ctemp = condition[i][None].repeat(n_samples, 1, 1)
122
+ cad = net.clap_encode_audio(ctemp)
123
+ uad = None
124
+ if scale != 1.0:
125
+ dummy = torch.zeros_like(ctemp)
126
+ uad = net.clap_encode_audio(dummy)
127
+ conditioning.append(torch.cat([uad, cad]))
128
+
129
+ elif condition_type == 'text':
130
+ ctx = net.clip_encode_text(n_samples * [condition[i]]).to(device)
131
+ utx = None
132
+ if scale != 1.0:
133
+ utx = net.clip_encode_text(n_samples * [""]).to(device)
134
+ conditioning.append(torch.cat([utx, ctx]))
135
+
136
+ shapes = []
137
+ for xtype_i in xtype:
138
+ if xtype_i == 'image':
139
+ h, w = [image_size, image_size]
140
+ shape = [n_samples, 4, h // 8, w // 8]
141
+ elif xtype_i == 'video':
142
+ h, w = [image_size, image_size]
143
+ shape = [n_samples, 4, num_frames, h // 8, w // 8]
144
+ elif xtype_i == 'text':
145
+ n = 768
146
+ shape = [n_samples, n]
147
+ elif xtype_i == 'audio':
148
+ h, w = [256, 16]
149
+ shape = [n_samples, 8, h, w]
150
+ else:
151
+ raise ValueError('unknown xtype: {}'.format(xtype_i))
152
+ shapes.append(shape)
153
+
154
+ z, _ = sampler.sample(
155
+ steps=ddim_steps,
156
+ shape=shapes,
157
+ condition=conditioning,
158
+ unconditional_guidance_scale=scale,
159
+ xtype=xtype,
160
+ condition_types=condition_types,
161
+ eta=ddim_eta,
162
+ verbose=False,
163
+ mix_weight=mix_weight)
164
+
165
+ out_all = []
166
+ for i, xtype_i in enumerate(xtype):
167
+ z[i] = z[i].to(device)
168
+ x_i = self.decode(z[i], xtype_i)
169
+ out_all.append(x_i)
170
+ return out_all
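
A heavily hedged usage sketch of the wrapper above. It assumes the repo's configs are importable and, for meaningful output, the pretrained checkpoints are loaded, so it is an outline rather than a drop-in script; the dummy image and the condition/output choices are placeholders.

import torch
from core.models.dani_model import dani_model

# Build the network from configs only; pass real checkpoint names via `pth`
# and load_weights=True for actual inference.
inference = dani_model(model='thesis_model', load_weights=False)

dummy_image = torch.rand(3, 256, 256)      # placeholder conditioning image
outputs = inference(
    xtype=['text'],                        # request a text sample
    condition=[dummy_image],
    condition_types=['image'],
    n_samples=1,
    ddim_steps=10,
    scale=7.5)
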
core/models/ddim/__pycache__/ddim.cpython-38.pyc ADDED
Binary file (6.27 kB). View file
 
core/models/ddim/__pycache__/ddim_vd.cpython-38.pyc ADDED
Binary file (4.29 kB). View file
 
core/models/ddim/__pycache__/diffusion_utils.cpython-38.pyc ADDED
Binary file (9.56 kB). View file
 
core/models/ddim/ddim.py ADDED
@@ -0,0 +1,224 @@
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from functools import partial
7
+
8
+ from .diffusion_utils import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9
+
10
+
11
+ class DDIMSampler(object):
12
+ def __init__(self, model, schedule="linear", **kwargs):
13
+ super().__init__()
14
+ self.model = model
15
+ self.ddpm_num_timesteps = model.num_timesteps
16
+ self.schedule = schedule
17
+
18
+ def register_buffer(self, name, attr):
19
+ # import torch
20
+ # device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
21
+ device = self.model.device
22
+
23
+ if type(attr) == torch.Tensor:
24
+ if attr.device != device:
25
+ attr = attr.to(device)
26
+ setattr(self, name, attr)
27
+
28
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
29
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize,
30
+ num_ddim_timesteps=ddim_num_steps,
31
+ num_ddpm_timesteps=self.ddpm_num_timesteps,
32
+ verbose=verbose)
33
+ alphas_cumprod = self.model.alphas_cumprod
34
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
35
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
36
+
37
+ self.register_buffer('betas', to_torch(self.model.betas))
38
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
39
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
40
+
41
+ # calculations for diffusion q(x_t | x_{t-1}) and others
42
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
43
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
44
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
45
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
46
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
47
+
48
+ # ddim sampling parameters
49
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
50
+ alphacums=alphas_cumprod.cpu(),
51
+ ddim_timesteps=self.ddim_timesteps,
52
+ eta=ddim_eta,verbose=verbose)
53
+
54
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
55
+ self.register_buffer('ddim_alphas', ddim_alphas)
56
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
57
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
58
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
59
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
60
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
61
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
62
+
63
+ @torch.no_grad()
64
+ def sample(self,
65
+ S,
66
+ batch_size,
67
+ shape,
68
+ conditioning=None,
69
+ callback=None,
70
+ normals_sequence=None,
71
+ img_callback=None,
72
+ quantize_x0=False,
73
+ eta=0.,
74
+ mask=None,
75
+ x0=None,
76
+ temperature=1.,
77
+ noise_dropout=0.,
78
+ score_corrector=None,
79
+ corrector_kwargs=None,
80
+ verbose=True,
81
+ x_T=None,
82
+ log_every_t=100,
83
+ unconditional_guidance_scale=1.,
84
+ unconditional_conditioning=None,
85
+ video_frame_share_noise=False,
86
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
87
+ **kwargs
88
+ ):
89
+ # device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
90
+ device = self.model.device
91
+
92
+ if conditioning is not None:
93
+ if isinstance(conditioning, dict):
94
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
95
+ if cbs != batch_size:
96
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
97
+ else:
98
+ if conditioning.shape[0] != batch_size:
99
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
100
+
101
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
102
+ # sampling
103
+ C, H, W = shape
104
+ size = (batch_size, C, H, W)
105
+ print(f'Data shape for DDIM sampling is {size}, eta {eta}')
106
+
107
+ samples, intermediates = self.ddim_sampling(conditioning, size,
108
+ callback=callback,
109
+ img_callback=img_callback,
110
+ quantize_denoised=quantize_x0,
111
+ mask=mask, x0=x0,
112
+ ddim_use_original_steps=False,
113
+ noise_dropout=noise_dropout,
114
+ temperature=temperature,
115
+ score_corrector=score_corrector,
116
+ corrector_kwargs=corrector_kwargs,
117
+ x_T=x_T,
118
+ log_every_t=log_every_t,
119
+ unconditional_guidance_scale=unconditional_guidance_scale,
120
+ unconditional_conditioning=unconditional_conditioning,
121
+ )
122
+ return samples, intermediates
123
+
124
+ @torch.no_grad()
125
+ def ddim_sampling(self,
126
+ cond, shape,
127
+ x_T=None,
128
+ ddim_use_original_steps=False,
129
+ callback=None,
130
+ timesteps=None,
131
+ quantize_denoised=False,
132
+ mask=None, x0=None,
133
+ img_callback=None, log_every_t=100,
134
+ temperature=1.,
135
+ noise_dropout=0.,
136
+ score_corrector=None,
137
+ corrector_kwargs=None,
138
+ unconditional_guidance_scale=1.,
139
+ unconditional_conditioning=None,):
140
+ device = self.model.betas.device
141
+ b = shape[0]
142
+ if x_T is None:
143
+ img = torch.randn(shape, device=device)
144
+ else:
145
+ img = x_T
146
+
147
+ if timesteps is None:
148
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
149
+ elif timesteps is not None and not ddim_use_original_steps:
150
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
151
+ timesteps = self.ddim_timesteps[:subset_end]
152
+
153
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
154
+ time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
155
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
156
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
157
+
158
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
159
+
160
+ for i, step in enumerate(iterator):
161
+ index = total_steps - i - 1
162
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
163
+
164
+ if mask is not None:
165
+ assert x0 is not None
166
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
167
+ img = img_orig * mask + (1. - mask) * img
168
+
169
+ outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
170
+ quantize_denoised=quantize_denoised, temperature=temperature,
171
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
172
+ corrector_kwargs=corrector_kwargs,
173
+ unconditional_guidance_scale=unconditional_guidance_scale,
174
+ unconditional_conditioning=unconditional_conditioning)
175
+ img, pred_x0 = outs
176
+ if callback: callback(i)
177
+ if img_callback: img_callback(pred_x0, i)
178
+
179
+ if index % log_every_t == 0 or index == total_steps - 1:
180
+ intermediates['x_inter'].append(img)
181
+ intermediates['pred_x0'].append(pred_x0)
182
+
183
+ return img, intermediates
184
+
185
+ @torch.no_grad()
186
+ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
187
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
188
+ unconditional_guidance_scale=1., unconditional_conditioning=None):
189
+ b, *_, device = *x.shape, x.device
190
+
191
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
192
+ e_t = self.model.apply_model(x, t, c)
193
+ else:
194
+ x_in = torch.cat([x] * 2)
195
+ t_in = torch.cat([t] * 2)
196
+ c_in = torch.cat([unconditional_conditioning, c])
197
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
198
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
199
+
200
+ if score_corrector is not None:
201
+ assert self.model.parameterization == "eps"
202
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
203
+
204
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
205
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
206
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
207
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
208
+ # select parameters corresponding to the currently considered timestep
209
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
210
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
211
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
212
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
213
+
214
+ # current prediction for x_0
215
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
216
+ if quantize_denoised:
217
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
218
+ # direction pointing to x_t
219
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
220
+ noise = sigma_t * noise_like(x, repeat_noise) * temperature
221
+ if noise_dropout > 0.:
222
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
223
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
224
+ return x_prev, pred_x0
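
Two pieces of p_sample_ddim that are easy to check in isolation are the classifier-free guidance combination and the DDIM update itself. A standalone numeric sketch with random stand-ins for the model outputs and made-up alpha values:

import torch

scale = 7.5
e_t_uncond = torch.randn(2, 4, 32, 32)     # noise prediction on empty conditioning
e_t_cond = torch.randn(2, 4, 32, 32)       # noise prediction on the real conditioning
e_t = e_t_uncond + scale * (e_t_cond - e_t_uncond)

# one DDIM step (eta = 0, so sigma_t = 0 and the update is deterministic)
x = torch.randn(2, 4, 32, 32)
a_t, a_prev = torch.tensor(0.5), torch.tensor(0.6)
pred_x0 = (x - (1 - a_t).sqrt() * e_t) / a_t.sqrt()
x_prev = a_prev.sqrt() * pred_x0 + (1. - a_prev).sqrt() * e_t
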
core/models/ddim/ddim_vd.py ADDED
@@ -0,0 +1,175 @@
1
+ """
2
+ https://github.com/SHI-Labs/Versatile-Diffusion
3
+ """
4
+
5
+ import torch
6
+ import numpy as np
7
+ from tqdm import tqdm
8
+ from functools import partial
9
+
10
+ from .diffusion_utils import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
11
+
12
+ from .ddim import DDIMSampler
13
+
14
+
15
+ class DDIMSampler_VD(DDIMSampler):
16
+ @torch.no_grad()
17
+ def sample(self,
18
+ steps,
19
+ shape,
20
+ xt=None,
21
+ condition=None,
22
+ unconditional_guidance_scale=1.,
23
+ xtype='image',
24
+ condition_types=['text'],
25
+ eta=0.,
26
+ temperature=1.,
27
+ mix_weight=None,
28
+ noise_dropout=0.,
29
+ verbose=True,
30
+ log_every_t=100, ):
31
+
32
+ self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose)
33
+ print(f'Data shape for DDIM sampling is {shape}, eta {eta}')
34
+ samples, intermediates = self.ddim_sampling(
35
+ shape,
36
+ xt=xt,
37
+ condition=condition,
38
+ unconditional_guidance_scale=unconditional_guidance_scale,
39
+ xtype=xtype,
40
+ condition_types=condition_types,
41
+ ddim_use_original_steps=False,
42
+ noise_dropout=noise_dropout,
43
+ temperature=temperature,
44
+ log_every_t=log_every_t,
45
+ mix_weight=mix_weight, )
46
+ return samples, intermediates
47
+
48
+ @torch.no_grad()
49
+ def ddim_sampling(self,
50
+ shape,
51
+ xt=None,
52
+ condition=None,
53
+ unconditional_guidance_scale=1.,
54
+ xtype=['image'],
55
+ condition_types=['text'],
56
+ ddim_use_original_steps=False,
57
+ timesteps=None,
58
+ noise_dropout=0.,
59
+ temperature=1.,
60
+ mix_weight=None,
61
+ log_every_t=100, ):
62
+
63
+ device = self.model.device
64
+ dtype = condition[0][0].dtype
65
+
66
+ if isinstance(shape[0], list):
67
+ bs = shape[0][0]
68
+ else:
69
+ bs = shape[0]
70
+ if xt is None:
71
+ if isinstance(shape[0], list):
72
+ xt = [torch.randn(shape_i, device=device, dtype=dtype) for shape_i in shape]
73
+ else:
74
+ xt = torch.randn(shape, device=device, dtype=dtype)
75
+
76
+ if timesteps is None:
77
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
78
+ elif timesteps is not None and not ddim_use_original_steps:
79
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
80
+ timesteps = self.ddim_timesteps[:subset_end]
81
+
82
+ intermediates = {'pred_xt': [], 'pred_x0': []}
83
+ time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
84
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
85
+ # print(f"Running DDIM Sampling with {total_steps} timesteps")
86
+
87
+ pred_xt = xt
88
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
89
+ for i, step in enumerate(iterator):
90
+ index = total_steps - i - 1
91
+ ts = torch.full((bs,), step, device=device, dtype=torch.long)
92
+
93
+ outs = self.p_sample_ddim(
94
+ pred_xt,
95
+ condition,
96
+ ts, index,
97
+ unconditional_guidance_scale=unconditional_guidance_scale,
98
+ xtype=xtype,
99
+ condition_types=condition_types,
100
+ use_original_steps=ddim_use_original_steps,
101
+ noise_dropout=noise_dropout,
102
+ temperature=temperature,
103
+ mix_weight=mix_weight, )
104
+ pred_xt, pred_x0 = outs
105
+
106
+ if index % log_every_t == 0 or index == total_steps - 1:
107
+ intermediates['pred_xt'].append(pred_xt)
108
+ intermediates['pred_x0'].append(pred_x0)
109
+
110
+ return pred_xt, intermediates
111
+
112
+ @torch.no_grad()
113
+ def p_sample_ddim(self, x,
114
+ condition,
115
+ t, index,
116
+ unconditional_guidance_scale=1.,
117
+ xtype=['image'],
118
+ condition_types=['text'],
119
+ repeat_noise=False,
120
+ use_original_steps=False,
121
+ noise_dropout=0.,
122
+ temperature=1.,
123
+ mix_weight=None, ):
124
+
125
+ b, *_, device = *x[0].shape, x[0].device
126
+
127
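+ # duplicate every latent (and the timesteps) along the batch so the unconditional and conditional branches share one forward pass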
+ x_in = []
128
+ for x_i in x:
129
+ x_in.append(torch.cat([x_i] * 2))
130
+ t_in = torch.cat([t] * 2)
131
+
132
+ out = self.model.model.diffusion_model(
133
+ x_in, t_in, condition, xtype=xtype, condition_types=condition_types, mix_weight=mix_weight)
134
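+ # each modality's output stacks the unconditional and conditional predictions; split them and apply classifier-free guidance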
+ e_t = []
135
+ for out_i in out:
136
+ e_t_uncond_i, e_t_i = out_i.chunk(2)
137
+ e_t_i = e_t_uncond_i + unconditional_guidance_scale * (e_t_i - e_t_uncond_i)
138
+ e_t_i = e_t_i.to(device)
139
+ e_t.append(e_t_i)
140
+
141
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
142
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
143
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
144
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
145
+ # select parameters corresponding to the currently considered timestep
146
+
147
+ x_prev = []
148
+ pred_x0 = []
149
+ device = x[0].device
150
+ dtype = x[0].dtype
151
+ for i, xtype_i in enumerate(xtype):
152
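+ # pick a broadcast shape matching the latent rank of each modality so the scalar schedule values broadcast correctly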
+ if xtype_i in ['image', 'frontal', 'lateral']:
153
+ extended_shape = (b, 1, 1, 1)
154
+ elif xtype_i == 'video':
155
+ extended_shape = (b, 1, 1, 1, 1)
156
+ elif xtype_i == 'text':
157
+ extended_shape = (b, 1)
158
+ elif xtype_i == 'audio':
159
+ extended_shape = (b, 1, 1, 1)
160
+
161
+ a_t = torch.full(extended_shape, alphas[index], device=device, dtype=dtype)
162
+ a_prev = torch.full(extended_shape, alphas_prev[index], device=device, dtype=dtype)
163
+ sigma_t = torch.full(extended_shape, sigmas[index], device=device, dtype=dtype)
164
+ sqrt_one_minus_at = torch.full(extended_shape, sqrt_one_minus_alphas[index], device=device, dtype=dtype)
165
+
166
+ # current prediction for x_0
167
+ pred_x0_i = (x[i] - sqrt_one_minus_at * e_t[i]) / a_t.sqrt()
168
+ dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t[i]
169
+ noise = sigma_t * noise_like(x[i], repeat_noise) * temperature
170
+ if noise_dropout > 0.:
171
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
172
+ x_prev_i = a_prev.sqrt() * pred_x0_i + dir_xt + noise
173
+ x_prev.append(x_prev_i)
174
+ pred_x0.append(pred_x0_i)
175
+ return x_prev, pred_x0
core/models/ddim/diffusion_utils.py ADDED
@@ -0,0 +1,273 @@
1
+ import os
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ from einops import repeat
7
+
8
+
9
+ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
10
+ if schedule == "linear":
11
+ betas = (
12
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
13
+ )
14
+
15
+ elif schedule == "cosine":
16
+ timesteps = (
17
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
18
+ )
19
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
20
+ alphas = torch.cos(alphas).pow(2)
21
+ alphas = alphas / alphas[0]
22
+ betas = 1 - alphas[1:] / alphas[:-1]
23
+ betas = np.clip(betas, a_min=0, a_max=0.999)
24
+
25
+ elif schedule == "sqrt_linear":
26
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
27
+ elif schedule == "sqrt":
28
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
29
+ else:
30
+ raise ValueError(f"schedule '{schedule}' unknown.")
31
+ return betas.numpy()
32
+
33
+
34
+ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
35
+ if ddim_discr_method == 'uniform':
36
+ c = num_ddpm_timesteps // num_ddim_timesteps
37
+ ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
38
+ elif ddim_discr_method == 'quad':
39
+ ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
40
+ else:
41
+ raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
42
+
43
+ # assert ddim_timesteps.shape[0] == num_ddim_timesteps
44
+ # add one to get the final alpha values right (the ones from first scale to data during sampling)
45
+ if num_ddpm_timesteps != 1000:
46
+ steps_out = ddim_timesteps + 1
47
+ else:
48
+ steps_out = ddim_timesteps
49
+ if verbose:
50
+ print(f'Selected timesteps for ddim sampler: {steps_out}')
51
+ return steps_out
52
+
53
+
54
+ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
55
+ # select alphas for computing the variance schedule
56
+ alphas = alphacums[ddim_timesteps]
57
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
58
+
59
+ # according to the formula provided in https://arxiv.org/abs/2010.02502
60
+ sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
61
+ if verbose:
62
+ print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
63
+ print(f'For the chosen value of eta, which is {eta}, '
64
+ f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
65
+ return sigmas, alphas, alphas_prev
66
+
67
+
68
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
69
+ """
70
+ Create a beta schedule that discretizes the given alpha_t_bar function,
71
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
72
+ :param num_diffusion_timesteps: the number of betas to produce.
73
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
74
+ produces the cumulative product of (1-beta) up to that
75
+ part of the diffusion process.
76
+ :param max_beta: the maximum beta to use; use values lower than 1 to
77
+ prevent singularities.
78
+ """
79
+ betas = []
80
+ for i in range(num_diffusion_timesteps):
81
+ t1 = i / num_diffusion_timesteps
82
+ t2 = (i + 1) / num_diffusion_timesteps
83
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
84
+ return np.array(betas)
85
+
86
+
87
+ def extract_into_tensor(a, t, x_shape):
88
+ b, *_ = t.shape
89
+ out = a.gather(-1, t)
90
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
91
+
92
+
93
+ def checkpoint(func, inputs, params, flag):
94
+ """
95
+ Evaluate a function without caching intermediate activations, allowing for
96
+ reduced memory at the expense of extra compute in the backward pass.
97
+ :param func: the function to evaluate.
98
+ :param inputs: the argument sequence to pass to `func`.
99
+ :param params: a sequence of parameters `func` depends on but does not
100
+ explicitly take as arguments.
101
+ :param flag: if False, disable gradient checkpointing.
102
+ """
103
+ if flag:
104
+ args = tuple(inputs) + tuple(params)
105
+ return CheckpointFunction.apply(func, len(inputs), *args)
106
+ else:
107
+ return func(*inputs)
108
+
109
+
110
+ class CheckpointFunction(torch.autograd.Function):
111
+ @staticmethod
112
+ def forward(ctx, run_function, length, *args):
113
+ ctx.run_function = run_function
114
+ ctx.input_tensors = list(args[:length])
115
+ ctx.input_params = list(args[length:])
116
+
117
+ with torch.no_grad():
118
+ output_tensors = ctx.run_function(*ctx.input_tensors)
119
+ return output_tensors
120
+
121
+ @staticmethod
122
+ def backward(ctx, *output_grads):
123
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
124
+ with torch.enable_grad():
125
+ # Fixes a bug where the first op in run_function modifies the
126
+ # Tensor storage in place, which is not allowed for detach()'d
127
+ # Tensors.
128
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
129
+ output_tensors = ctx.run_function(*shallow_copies)
130
+ input_grads = torch.autograd.grad(
131
+ output_tensors,
132
+ ctx.input_tensors + ctx.input_params,
133
+ output_grads,
134
+ allow_unused=True,
135
+ )
136
+ del ctx.input_tensors
137
+ del ctx.input_params
138
+ del output_tensors
139
+ return (None, None) + input_grads
140
+
141
+
142
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
143
+ """
144
+ Create sinusoidal timestep embeddings.
145
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
146
+ These may be fractional.
147
+ :param dim: the dimension of the output.
148
+ :param max_period: controls the minimum frequency of the embeddings.
149
+ :return: an [N x dim] Tensor of positional embeddings.
150
+ """
151
+ if not repeat_only:
152
+ half = dim // 2
153
+ freqs = torch.exp(
154
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
155
+ ).to(device=timesteps.device)
156
+ args = timesteps[:, None].float() * freqs[None]
157
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
158
+ if dim % 2:
159
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
160
+ else:
161
+ embedding = repeat(timesteps, 'b -> b d', d=dim)
162
+ return embedding
163
+
164
+
165
+ def zero_module(module):
166
+ """
167
+ Zero out the parameters of a module and return it.
168
+ """
169
+ for p in module.parameters():
170
+ p.detach().zero_()
171
+ return module
172
+
173
+
174
+ def scale_module(module, scale):
175
+ """
176
+ Scale the parameters of a module and return it.
177
+ """
178
+ for p in module.parameters():
179
+ p.detach().mul_(scale)
180
+ return module
181
+
182
+
183
+ def mean_flat(tensor):
184
+ """
185
+ Take the mean over all non-batch dimensions.
186
+ """
187
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
188
+
189
+
190
+ def normalization(channels):
191
+ """
192
+ Make a standard normalization layer.
193
+ :param channels: number of input channels.
194
+ :return: an nn.Module for normalization.
195
+ """
196
+ return GroupNorm32(32, channels)
197
+
198
+
199
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
200
+ class SiLU(nn.Module):
201
+ def forward(self, x):
202
+ return x * torch.sigmoid(x)
203
+
204
+
205
+ class GroupNorm32(nn.GroupNorm):
206
+ def forward(self, x):
207
+ # return super().forward(x.float()).type(x.dtype)
208
+ return super().forward(x)
209
+
210
+
211
+ def conv_nd(dims, *args, **kwargs):
212
+ """
213
+ Create a 1D, 2D, or 3D convolution module.
214
+ """
215
+ if dims == 1:
216
+ return nn.Conv1d(*args, **kwargs)
217
+ elif dims == 2:
218
+ return nn.Conv2d(*args, **kwargs)
219
+ elif dims == 3:
220
+ return nn.Conv3d(*args, **kwargs)
221
+ raise ValueError(f"unsupported dimensions: {dims}")
222
+
223
+
224
+ def linear(*args, **kwargs):
225
+ """
226
+ Create a linear module.
227
+ """
228
+ return nn.Linear(*args, **kwargs)
229
+
230
+
231
+ def avg_pool_nd(dims, *args, **kwargs):
232
+ """
233
+ Create a 1D, 2D, or 3D average pooling module.
234
+ """
235
+ if dims == 1:
236
+ return nn.AvgPool1d(*args, **kwargs)
237
+ elif dims == 2:
238
+ return nn.AvgPool2d(*args, **kwargs)
239
+ elif dims == 3:
240
+ return nn.AvgPool3d(*args, **kwargs)
241
+ raise ValueError(f"unsupported dimensions: {dims}")
242
+
243
+
244
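+ # NOTE: relies on instantiate_from_config, which is assumed to be provided by the surrounding framework (it is not imported in this file)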
+ class HybridConditioner(nn.Module):
245
+
246
+ def __init__(self, c_concat_config, c_crossattn_config):
247
+ super().__init__()
248
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
249
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
250
+
251
+ def forward(self, c_concat, c_crossattn):
252
+ c_concat = self.concat_conditioner(c_concat)
253
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
254
+ return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
255
+
256
+
257
+ def noise_like(x, repeat=False):
258
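+ # repeat=True reuses a single noise sample for the whole batch instead of drawing per-sample noise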
+ noise = torch.randn_like(x)
259
+ if repeat:
260
+ bs = x.shape[0]
261
+ noise = noise[0:1].repeat(bs, *((1,) * (len(x.shape) - 1)))
262
+ return noise
263
+
264
+ ##########################
265
+ # inherit from ldm.utils #
266
+ ##########################
267
+
268
+
269
+ def count_params(model, verbose=False):
270
+ total_params = sum(p.numel() for p in model.parameters())
271
+ if verbose:
272
+ print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
273
+ return total_params
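The schedule helpers above are pure NumPy/PyTorch functions and can be sanity-checked in isolation. A minimal sketch with arbitrary illustrative values (assumes the repository root is on PYTHONPATH):

```python
import torch
import numpy as np

from core.models.ddim.diffusion_utils import (
    make_beta_schedule, make_ddim_timesteps,
    make_ddim_sampling_parameters, timestep_embedding,
)

# 1000-step linear beta schedule and its cumulative alphas
betas = make_beta_schedule("linear", n_timestep=1000)
alphas_cumprod = np.cumprod(1.0 - betas)

# 50 uniformly spaced DDIM steps and the matching (sigma, alpha, alpha_prev) values
ddim_steps = make_ddim_timesteps("uniform", num_ddim_timesteps=50,
                                 num_ddpm_timesteps=1000, verbose=False)
sigmas, alphas, alphas_prev = make_ddim_sampling_parameters(
    alphas_cumprod, ddim_steps, eta=0.0, verbose=False)

# sinusoidal embeddings for a batch of four timesteps
emb = timestep_embedding(torch.tensor([0, 10, 100, 999]), dim=320)
print(betas.shape, ddim_steps.shape, sigmas.shape, emb.shape)
```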
core/models/ema.py ADDED
@@ -0,0 +1,76 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class LitEma(nn.Module):
6
+ def __init__(self, model, decay=0.9999, use_num_updates=True):
7
+ super().__init__()
8
+ if decay < 0.0 or decay > 1.0:
9
+ raise ValueError('Decay must be between 0 and 1')
10
+
11
+ self.m_name2s_name = {}
12
+ self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13
+ self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_updates
14
+ else torch.tensor(-1, dtype=torch.int))
15
+
16
+ for name, p in model.named_parameters():
17
+ if p.requires_grad:
18
+ # remove as '.'-character is not allowed in buffers
19
+ s_name = name.replace('.', '')
20
+ self.m_name2s_name.update({name: s_name})
21
+ self.register_buffer(s_name, p.clone().detach().data)
22
+
23
+ self.collected_params = []
24
+
25
+ def forward(self, model):
26
+ decay = self.decay
27
+
28
+ if self.num_updates >= 0:
29
+ self.num_updates += 1
30
+ decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
31
+
32
+ one_minus_decay = 1.0 - decay
33
+
34
+ with torch.no_grad():
35
+ m_param = dict(model.named_parameters())
36
+ shadow_params = dict(self.named_buffers())
37
+
38
+ for key in m_param:
39
+ if m_param[key].requires_grad:
40
+ sname = self.m_name2s_name[key]
41
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
42
+ shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
43
+ else:
44
+ assert not key in self.m_name2s_name
45
+
46
+ def copy_to(self, model):
47
+ m_param = dict(model.named_parameters())
48
+ shadow_params = dict(self.named_buffers())
49
+ for key in m_param:
50
+ if m_param[key].requires_grad:
51
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
52
+ else:
53
+ assert not key in self.m_name2s_name
54
+
55
+ def store(self, parameters):
56
+ """
57
+ Save the current parameters for restoring later.
58
+ Args:
59
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
60
+ temporarily stored.
61
+ """
62
+ self.collected_params = [param.clone() for param in parameters]
63
+
64
+ def restore(self, parameters):
65
+ """
66
+ Restore the parameters stored with the `store` method.
67
+ Useful to validate the model with EMA parameters without affecting the
68
+ original optimization process. Store the parameters before the
69
+ `copy_to` method. After validation (or model saving), use this to
70
+ restore the former parameters.
71
+ Args:
72
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
73
+ updated with the stored parameters.
74
+ """
75
+ for c_param, param in zip(self.collected_params, parameters):
76
+ param.data.copy_(c_param.data)
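LitEma mirrors every trainable parameter into a shadow buffer and nudges it toward the live weights after each update. A minimal usage sketch with a toy model (assumes the repository root is importable; the loss and loop are placeholders):

```python
import torch
from torch import nn

from core.models.ema import LitEma

model = nn.Linear(4, 2)
ema = LitEma(model, decay=0.999)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)

for _ in range(10):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    ema(model)  # update the shadow (EMA) copy after each optimizer step

# evaluate with EMA weights, then restore the raw training weights
ema.store(model.parameters())
ema.copy_to(model)
# ... run validation with the EMA weights here ...
ema.restore(model.parameters())
```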
core/models/encoders/__pycache__/clap.cpython-311.pyc ADDED
Binary file (7.09 kB). View file
 
core/models/encoders/__pycache__/clap.cpython-38.pyc ADDED
Binary file (4.16 kB). View file
 
core/models/encoders/__pycache__/clip.cpython-311.pyc ADDED
Binary file (10.4 kB). View file
 
core/models/encoders/__pycache__/clip.cpython-38.pyc ADDED
Binary file (6 kB). View file
 
core/models/encoders/clap.py ADDED
@@ -0,0 +1,134 @@
1
+ import torch
2
+ import torch.nn as nn
+ import torch.nn.functional as F
3
+ import torchaudio
4
+
5
+ from .clap_modules.open_clip import create_model
6
+ from .clap_modules.training.data import get_audio_features
7
+
8
+ from ..common.get_model import register
9
+
10
+
11
+ @register('clap_audio')
12
+ class CLAPAudioEmbeddingClassifierFreev2(nn.Module):
13
+ """Uses the CLAP audio encoder"""
14
+ def __init__(
15
+ self,
16
+ pretrained_path="",
17
+ key="waveform",
18
+ sampling_rate=16000,
19
+ embed_mode="audio",
20
+ unconditional_prob=0.1,
21
+ random_mute=False,
22
+ max_random_mute_portion=0.5,
23
+ training_mode=True,
24
+ joint_embed_shape=768,
25
+ embed_shape=512,
26
+ num_layers=12,
27
+ depths=[2, 2, 6, 2],
28
+ amodel="HTSAT-large",
29
+ ):
30
+ super().__init__()
31
+
32
+ self.key = key
33
+ self.amodel = amodel # or 'PANN-14'
34
+ self.tmodel = "roberta" # the best text encoder in our training
35
+ self.enable_fusion = False # False if you do not want to use the fusion model
36
+ self.fusion_type = "aff_2d"
37
+ self.pretrained = pretrained_path
38
+ self.embed_mode = embed_mode
39
+ self.embed_mode_orig = embed_mode
40
+ self.sampling_rate = sampling_rate
41
+ self.unconditional_prob = unconditional_prob
42
+ self.random_mute = random_mute
43
+ self.joint_embed_shape = joint_embed_shape
44
+ self.max_random_mute_portion = max_random_mute_portion
45
+ self.training_mode = training_mode
46
+ self.model, self.model_cfg = create_model(
47
+ self.amodel,
48
+ self.tmodel,
49
+ self.pretrained,
50
+ precision="fp32",
51
+ device="cpu",
52
+ enable_fusion=self.enable_fusion,
53
+ fusion_type=self.fusion_type,
54
+ joint_embed_shape=self.joint_embed_shape,
55
+ )
56
+
57
+ def get_dtype(self):
58
+ return next(self.model.parameters()).dtype
59
+
60
+ def get_unconditional_condition(self, batchsize):
61
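+ # NOTE: assumes self.tokenizer has been attached elsewhere; it is not defined in this class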
+ self.unconditional_token = self.model.get_text_embedding(
62
+ self.tokenizer(["", ""])
63
+ )[0:1]
64
+ return torch.cat([self.unconditional_token.unsqueeze(0)] * batchsize, dim=0)
65
+
66
+ def batch_to_list(self, batch):
67
+ ret = []
68
+ for i in range(batch.size(0)):
69
+ ret.append(batch[i])
70
+ return ret
71
+
72
+ def make_decision(self, probability):
73
+ if float(torch.rand(1)) < probability:
74
+ return True
75
+ else:
76
+ return False
77
+
78
+ def random_uniform(self, start, end):
79
+ val = torch.rand(1).item()
80
+ return start + (end - start) * val
81
+
82
+ def _random_mute(self, waveform):
83
+ # waveform: [bs, t-steps]
84
+ t_steps = waveform.size(-1)
85
+ for i in range(waveform.size(0)):
86
+ mute_size = int(
87
+ self.random_uniform(0, end=int(t_steps * self.max_random_mute_portion))
88
+ )
89
+ mute_start = int(self.random_uniform(0, t_steps - mute_size))
90
+ waveform[i, mute_start : mute_start + mute_size] = 0
91
+ return waveform
92
+
93
+ def cos_similarity(self, waveform, text):
94
+ # waveform: [bs, t_steps]
95
+ with torch.no_grad():
96
+ self.embed_mode = "audio"
97
+ audio_emb = self(waveform.cuda())
98
+ self.embed_mode = "text"
99
+ text_emb = self(text)
100
+ similarity = F.cosine_similarity(audio_emb, text_emb, dim=2)
101
+ return similarity.squeeze()
102
+
103
+ def forward(self, batch, key=None):
104
+
105
+ # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode
106
+ if self.embed_mode == "audio":
107
+ audio_dict_list = []
108
+ assert (
109
+ self.sampling_rate == 16000
110
+ ), "We only support 16000 sampling rate"
111
+ # batch: [bs, 1, t-samples]
112
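+ # CLAP's audio branch expects 48 kHz input, so the 16 kHz waveforms are resampled first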
+ batch = torchaudio.functional.resample(
113
+ batch, orig_freq=self.sampling_rate, new_freq=48000
114
+ )
115
+
116
+ for waveform in self.batch_to_list(batch):
117
+ audio_dict = {}
118
+ audio_dict = get_audio_features(
119
+ audio_dict,
120
+ waveform.squeeze(),
121
+ 480000,
122
+ data_truncating="fusion",
123
+ data_filling="repeatpad",
124
+ audio_cfg=self.model_cfg["audio_cfg"],
125
+ dtype=self.get_dtype(),
126
+ )
127
+ audio_dict_list.append(audio_dict)
128
+ # [bs, 768]
129
+ embed = self.model.get_audio_embedding(audio_dict_list)
130
+
131
+ embed = embed.unsqueeze(1)
132
+
133
+ # [bs, 1, 768]
134
+ return embed