Spaces:
Sleeping
Sleeping
Upload 340 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +2 -0
- app.py +44 -0
- audiocaps_test_16000_struct.tsv +0 -0
- configs/audiolcm.yaml +130 -0
- configs/autoencoder1d.yaml +74 -0
- configs/teacher.yaml +121 -0
- infer.sh +4 -0
- infer_api.sh +4 -0
- ldm/__pycache__/lr_scheduler.cpython-37.pyc +0 -0
- ldm/__pycache__/lr_scheduler.cpython-38.pyc +0 -0
- ldm/__pycache__/util.cpython-310.pyc +0 -0
- ldm/__pycache__/util.cpython-37.pyc +0 -0
- ldm/__pycache__/util.cpython-38.pyc +0 -0
- ldm/data/__pycache__/joinaudiodataset_624.cpython-38.pyc +0 -0
- ldm/data/__pycache__/joinaudiodataset_anylen.cpython-37.pyc +0 -0
- ldm/data/__pycache__/joinaudiodataset_anylen.cpython-38.pyc +0 -0
- ldm/data/__pycache__/joinaudiodataset_struct.cpython-38.pyc +0 -0
- ldm/data/__pycache__/joinaudiodataset_struct_anylen.cpython-38.pyc +0 -0
- ldm/data/__pycache__/joinaudiodataset_struct_sample_anylen.cpython-37.pyc +0 -0
- ldm/data/__pycache__/joinaudiodataset_struct_sample_anylen.cpython-38.pyc +0 -0
- ldm/data/__pycache__/tsvdataset.cpython-38.pyc +0 -0
- ldm/data/joinaudiodataset_624.py +93 -0
- ldm/data/joinaudiodataset_anylen.py +331 -0
- ldm/data/joinaudiodataset_struct.py +95 -0
- ldm/data/joinaudiodataset_struct_anylen.py +336 -0
- ldm/data/joinaudiodataset_struct_sample.py +103 -0
- ldm/data/joinaudiodataset_struct_sample_anylen.py +230 -0
- ldm/data/preprocess/NAT_mel.py +131 -0
- ldm/data/preprocess/__pycache__/NAT_mel.cpython-38.pyc +0 -0
- ldm/data/preprocess/__pycache__/NAT_mel.cpython-39.pyc +0 -0
- ldm/data/preprocess/add_duration.py +45 -0
- ldm/data/preprocess/mel_spec.py +201 -0
- ldm/data/test.py +224 -0
- ldm/data/tsv_dirs/full_data/V1_new/audiocaps_train_16000.tsv +0 -0
- ldm/data/tsv_dirs/full_data/V2/MACS.tsv +0 -0
- ldm/data/tsv_dirs/full_data/V2/WavText5K.tsv +0 -0
- ldm/data/tsv_dirs/full_data/V2/adobe.tsv +0 -0
- ldm/data/tsv_dirs/full_data/V2/audiostock.tsv +0 -0
- ldm/data/tsv_dirs/full_data/V2/epidemic_sound.tsv +3 -0
- ldm/data/tsv_dirs/full_data/caps_struct/audiocaps_train_16000_struct2.tsv +0 -0
- ldm/data/tsv_dirs/full_data/clotho.tsv +0 -0
- ldm/data/tsvdataset.py +67 -0
- ldm/lr_scheduler.py +98 -0
- ldm/models/__pycache__/autoencoder.cpython-37.pyc +0 -0
- ldm/models/__pycache__/autoencoder.cpython-38.pyc +0 -0
- ldm/models/__pycache__/autoencoder.cpython-39.pyc +0 -0
- ldm/models/__pycache__/autoencoder1d.cpython-37.pyc +0 -0
- ldm/models/__pycache__/autoencoder1d.cpython-38.pyc +0 -0
- ldm/models/__pycache__/autoencoder_multi.cpython-38.pyc +0 -0
- ldm/models/autoencoder.py +504 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
ldm/data/tsv_dirs/full_data/V2/epidemic_sound.tsv filter=lfs diff=lfs merge=lfs -text
|
37 |
+
vocoder/BigVGAN/LibriTTS/train-full.txt filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio
|
2 |
+
|
3 |
+
def infer(prompt):
|
4 |
+
config = OmegaConf.load("configs/audiolcm.yaml")
|
5 |
+
|
6 |
+
# print("-------quick debug no load ckpt---------")
|
7 |
+
# model = instantiate_from_config(config['model'])# for quick debug
|
8 |
+
model = load_model_from_config(config,
|
9 |
+
"../logs/2024-04-21T14-50-11_text2music-audioset-nonoverlap/epoch=000184.ckpt")
|
10 |
+
|
11 |
+
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
12 |
+
model = model.to(device)
|
13 |
+
|
14 |
+
sampler = LCMSampler(model)
|
15 |
+
|
16 |
+
os.makedirs("results/test", exist_ok=True)
|
17 |
+
|
18 |
+
vocoder = VocoderBigVGAN("../vocoder/logs/bigvnat16k93.5w", device)
|
19 |
+
|
20 |
+
generator = GenSamples(sampler, model, "results/test", vocoder, save_mel=False, save_wav=True,
|
21 |
+
original_inference_steps=config.model.params.num_ddim_timesteps)
|
22 |
+
csv_dicts = []
|
23 |
+
|
24 |
+
with torch.no_grad():
|
25 |
+
with model.ema_scope():
|
26 |
+
wav_name = f'{prompt.strip().replace(" ", "-")}'
|
27 |
+
generator.gen_test_sample(prompt, wav_name=wav_name)
|
28 |
+
|
29 |
+
print(f"Your samples are ready and waiting four you here: \nresults/test \nEnjoy.")
|
30 |
+
|
31 |
+
|
32 |
+
def my_inference_function(prompt_oir):
|
33 |
+
prompt = {'ori_caption':prompt_oir,'struct_caption':prompt_oir}
|
34 |
+
file_path = infer(prompt)
|
35 |
+
return "test.wav"
|
36 |
+
|
37 |
+
|
38 |
+
|
39 |
+
gradio_interface = gradio.Interface(
|
40 |
+
fn = my_inference_function,
|
41 |
+
inputs = "text",
|
42 |
+
outputs = "audio"
|
43 |
+
)
|
44 |
+
gradio_interface.launch()
|
audiocaps_test_16000_struct.tsv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
configs/audiolcm.yaml
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 3.0e-06
|
3 |
+
target: ldm.models.diffusion.lcm_audio.LCM_audio
|
4 |
+
params:
|
5 |
+
linear_start: 0.00085
|
6 |
+
linear_end: 0.012
|
7 |
+
num_timesteps_cond: 1
|
8 |
+
log_every_t: 200
|
9 |
+
timesteps: 1000
|
10 |
+
first_stage_key: image
|
11 |
+
cond_stage_key: caption
|
12 |
+
mel_dim: 20
|
13 |
+
mel_length: 312
|
14 |
+
channels: 0
|
15 |
+
cond_stage_trainable: False
|
16 |
+
conditioning_key: crossattn
|
17 |
+
monitor: val/loss_simple_ema
|
18 |
+
scale_by_std: true
|
19 |
+
use_lcm: True
|
20 |
+
num_ddim_timesteps: 50
|
21 |
+
w_min: 4
|
22 |
+
w_max: 12
|
23 |
+
ckpt_path: ../ckpt/maa2.ckpt
|
24 |
+
|
25 |
+
use_ema: false
|
26 |
+
scheduler_config:
|
27 |
+
target: ldm.lr_scheduler.LambdaLinearScheduler
|
28 |
+
params:
|
29 |
+
warm_up_steps:
|
30 |
+
- 10000
|
31 |
+
cycle_lengths:
|
32 |
+
- 10000000000000
|
33 |
+
f_start:
|
34 |
+
- 1.0e-06
|
35 |
+
f_max:
|
36 |
+
- 1.0
|
37 |
+
f_min:
|
38 |
+
- 1.0
|
39 |
+
unet_config:
|
40 |
+
target: ldm.modules.diffusionmodules.concatDiT.ConcatDiT2MLP
|
41 |
+
params:
|
42 |
+
in_channels: 20
|
43 |
+
context_dim: 1024
|
44 |
+
hidden_size: 576
|
45 |
+
num_heads: 8
|
46 |
+
depth: 4
|
47 |
+
max_len: 1000
|
48 |
+
first_stage_config:
|
49 |
+
target: ldm.models.autoencoder1d.AutoencoderKL
|
50 |
+
params:
|
51 |
+
embed_dim: 20
|
52 |
+
monitor: val/rec_loss
|
53 |
+
ckpt_path: ./model/AutoencoderKL/epoch=000032.ckpt
|
54 |
+
ddconfig:
|
55 |
+
double_z: true
|
56 |
+
in_channels: 80
|
57 |
+
out_ch: 80
|
58 |
+
z_channels: 20
|
59 |
+
kernel_size: 5
|
60 |
+
ch: 384
|
61 |
+
ch_mult:
|
62 |
+
- 1
|
63 |
+
- 2
|
64 |
+
- 4
|
65 |
+
num_res_blocks: 2
|
66 |
+
attn_layers:
|
67 |
+
- 3
|
68 |
+
down_layers:
|
69 |
+
- 0
|
70 |
+
dropout: 0.0
|
71 |
+
lossconfig:
|
72 |
+
target: torch.nn.Identity
|
73 |
+
cond_stage_config:
|
74 |
+
target: ldm.modules.encoders.modules.FrozenCLAPFLANEmbedder
|
75 |
+
params:
|
76 |
+
weights_path: ./model/FrozenCLAPFLANEmbedder/CLAP_weights_2022.pth
|
77 |
+
|
78 |
+
lightning:
|
79 |
+
callbacks:
|
80 |
+
image_logger:
|
81 |
+
target: main.AudioLogger
|
82 |
+
params:
|
83 |
+
sample_rate: 16000
|
84 |
+
for_specs: true
|
85 |
+
increase_log_steps: false
|
86 |
+
batch_frequency: 5000
|
87 |
+
max_images: 8
|
88 |
+
melvmin: -5
|
89 |
+
melvmax: 1.5
|
90 |
+
vocoder_cfg:
|
91 |
+
target: vocoder.bigvgan.models.VocoderBigVGAN
|
92 |
+
params:
|
93 |
+
ckpt_vocoder: ./vocoder/logs/bigvnat16k93.5w
|
94 |
+
trainer:
|
95 |
+
benchmark: True
|
96 |
+
gradient_clip_val: 1.0
|
97 |
+
replace_sampler_ddp: false
|
98 |
+
max_epochs: 100
|
99 |
+
modelcheckpoint:
|
100 |
+
params:
|
101 |
+
monitor: epoch
|
102 |
+
mode: max
|
103 |
+
# every_n_train_steps: 2000
|
104 |
+
save_top_k: 100
|
105 |
+
every_n_epochs: 3
|
106 |
+
|
107 |
+
|
108 |
+
data:
|
109 |
+
target: main.SpectrogramDataModuleFromConfig
|
110 |
+
params:
|
111 |
+
batch_size: 8
|
112 |
+
num_workers: 32
|
113 |
+
spec_dir_path: 'ldm/data/tsv_dirs/full_data/caps_struct'
|
114 |
+
mel_num: 80
|
115 |
+
train:
|
116 |
+
target: ldm.data.joinaudiodataset_struct_anylen.JoinSpecsTrain
|
117 |
+
params:
|
118 |
+
specs_dataset_cfg:
|
119 |
+
validation:
|
120 |
+
target: ldm.data.joinaudiodataset_struct_anylen.JoinSpecsValidation
|
121 |
+
params:
|
122 |
+
specs_dataset_cfg:
|
123 |
+
|
124 |
+
test_dataset:
|
125 |
+
target: ldm.data.tsvdataset.TSVDatasetStruct
|
126 |
+
params:
|
127 |
+
tsv_path: audiocaps_test_16000_struct.tsv
|
128 |
+
spec_crop_len: 624
|
129 |
+
|
130 |
+
|
configs/autoencoder1d.yaml
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 4.5e-06
|
3 |
+
target: ldm.models.autoencoder1d.AutoencoderKL
|
4 |
+
params:
|
5 |
+
embed_dim: 20
|
6 |
+
monitor: val/rec_loss
|
7 |
+
ddconfig:
|
8 |
+
double_z: true
|
9 |
+
in_channels: 80
|
10 |
+
out_ch: 80
|
11 |
+
z_channels: 20
|
12 |
+
kernel_size: 5
|
13 |
+
ch: 384
|
14 |
+
ch_mult:
|
15 |
+
- 1
|
16 |
+
- 2
|
17 |
+
- 4
|
18 |
+
num_res_blocks: 2
|
19 |
+
attn_layers:
|
20 |
+
- 3
|
21 |
+
down_layers:
|
22 |
+
- 0
|
23 |
+
dropout: 0.0
|
24 |
+
lossconfig:
|
25 |
+
target: ldm.modules.losses_audio.contperceptual.LPAPSWithDiscriminator
|
26 |
+
params:
|
27 |
+
disc_start: 80001
|
28 |
+
perceptual_weight: 0.0
|
29 |
+
kl_weight: 1.0e-06
|
30 |
+
disc_weight: 0.5
|
31 |
+
disc_in_channels: 1
|
32 |
+
disc_loss: mse
|
33 |
+
disc_factor: 2
|
34 |
+
disc_conditional: false
|
35 |
+
r1_reg_weight: 3
|
36 |
+
|
37 |
+
lightning:
|
38 |
+
callbacks:
|
39 |
+
image_logger:
|
40 |
+
target: main.AudioLogger
|
41 |
+
params:
|
42 |
+
for_specs: true
|
43 |
+
increase_log_steps: false
|
44 |
+
batch_frequency: 5000
|
45 |
+
max_images: 8
|
46 |
+
rescale: false
|
47 |
+
melvmin: -5
|
48 |
+
melvmax: 1.5
|
49 |
+
vocoder_cfg:
|
50 |
+
target: vocoder.bigvgan.models.VocoderBigVGAN
|
51 |
+
params:
|
52 |
+
ckpt_vocoder: vocoder/logs/bigvnat16k93.5w
|
53 |
+
trainer:
|
54 |
+
sync_batchnorm: false # not working with r1_regularization
|
55 |
+
strategy: ddp
|
56 |
+
|
57 |
+
|
58 |
+
data:
|
59 |
+
target: main.SpectrogramDataModuleFromConfig
|
60 |
+
params:
|
61 |
+
batch_size: 4
|
62 |
+
num_workers: 16
|
63 |
+
spec_dir_path: ldm/data/tsv_dirs/full_data/V1_new
|
64 |
+
mel_num: 80
|
65 |
+
spec_len: 624
|
66 |
+
spec_crop_len: 624
|
67 |
+
train:
|
68 |
+
target: ldm.data.joinaudiodataset_624.JoinSpecsTrain
|
69 |
+
params:
|
70 |
+
specs_dataset_cfg: null
|
71 |
+
validation:
|
72 |
+
target: ldm.data.joinaudiodataset_624.JoinSpecsValidation
|
73 |
+
params:
|
74 |
+
specs_dataset_cfg: null
|
configs/teacher.yaml
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
base_learning_rate: 3.0e-06
|
3 |
+
target: ldm.models.diffusion.ddpm_audio.LatentDiffusion_audio
|
4 |
+
params:
|
5 |
+
linear_start: 0.00085
|
6 |
+
linear_end: 0.012
|
7 |
+
num_timesteps_cond: 1
|
8 |
+
log_every_t: 200
|
9 |
+
timesteps: 1000
|
10 |
+
first_stage_key: image
|
11 |
+
cond_stage_key: caption
|
12 |
+
mel_dim: 20
|
13 |
+
mel_length: 312
|
14 |
+
channels: 0
|
15 |
+
cond_stage_trainable: True
|
16 |
+
conditioning_key: crossattn
|
17 |
+
monitor: val/loss_simple_ema
|
18 |
+
scale_by_std: true
|
19 |
+
use_ema: false
|
20 |
+
scheduler_config:
|
21 |
+
target: ldm.lr_scheduler.LambdaLinearScheduler
|
22 |
+
params:
|
23 |
+
warm_up_steps:
|
24 |
+
- 10000
|
25 |
+
cycle_lengths:
|
26 |
+
- 10000000000000
|
27 |
+
f_start:
|
28 |
+
- 1.0e-06
|
29 |
+
f_max:
|
30 |
+
- 1.0
|
31 |
+
f_min:
|
32 |
+
- 1.0
|
33 |
+
unet_config:
|
34 |
+
target: ldm.modules.diffusionmodules.concatDiT.ConcatDiT2MLP
|
35 |
+
params:
|
36 |
+
in_channels: 20
|
37 |
+
context_dim: 1024
|
38 |
+
hidden_size: 576
|
39 |
+
num_heads: 8
|
40 |
+
depth: 4
|
41 |
+
max_len: 1000
|
42 |
+
first_stage_config:
|
43 |
+
target: ldm.models.autoencoder1d.AutoencoderKL
|
44 |
+
params:
|
45 |
+
embed_dim: 20
|
46 |
+
monitor: val/rec_loss
|
47 |
+
ckpt_path: logs/trainae/ckpt/epoch=000032.ckpt
|
48 |
+
ddconfig:
|
49 |
+
double_z: true
|
50 |
+
in_channels: 80
|
51 |
+
out_ch: 80
|
52 |
+
z_channels: 20
|
53 |
+
kernel_size: 5
|
54 |
+
ch: 384
|
55 |
+
ch_mult:
|
56 |
+
- 1
|
57 |
+
- 2
|
58 |
+
- 4
|
59 |
+
num_res_blocks: 2
|
60 |
+
attn_layers:
|
61 |
+
- 3
|
62 |
+
down_layers:
|
63 |
+
- 0
|
64 |
+
dropout: 0.0
|
65 |
+
lossconfig:
|
66 |
+
target: torch.nn.Identity
|
67 |
+
cond_stage_config:
|
68 |
+
target: ldm.modules.encoders.modules.FrozenCLAPFLANEmbedder
|
69 |
+
params:
|
70 |
+
weights_path: useful_ckpts/CLAP/CLAP_weights_2022.pth
|
71 |
+
|
72 |
+
lightning:
|
73 |
+
callbacks:
|
74 |
+
image_logger:
|
75 |
+
target: main.AudioLogger
|
76 |
+
params:
|
77 |
+
sample_rate: 16000
|
78 |
+
for_specs: true
|
79 |
+
increase_log_steps: false
|
80 |
+
batch_frequency: 5000
|
81 |
+
max_images: 8
|
82 |
+
melvmin: -5
|
83 |
+
melvmax: 1.5
|
84 |
+
vocoder_cfg:
|
85 |
+
target: vocoder.bigvgan.models.VocoderBigVGAN
|
86 |
+
params:
|
87 |
+
ckpt_vocoder: vocoder/logs/bigvnat16k93.5w
|
88 |
+
trainer:
|
89 |
+
benchmark: True
|
90 |
+
gradient_clip_val: 1.0
|
91 |
+
replace_sampler_ddp: false
|
92 |
+
modelcheckpoint:
|
93 |
+
params:
|
94 |
+
monitor: epoch
|
95 |
+
mode: max
|
96 |
+
save_top_k: 10
|
97 |
+
every_n_epochs: 5
|
98 |
+
|
99 |
+
data:
|
100 |
+
target: main.SpectrogramDataModuleFromConfig
|
101 |
+
params:
|
102 |
+
batch_size: 4
|
103 |
+
num_workers: 32
|
104 |
+
main_spec_dir_path: 'ldm/data/tsv_dirs/full_data/caps_struct'
|
105 |
+
other_spec_dir_path: 'ldm/data/tsv_dirs/full_data/V2'
|
106 |
+
mel_num: 80
|
107 |
+
train:
|
108 |
+
target: ldm.data.joinaudiodataset_struct_sample_anylen.JoinSpecsTrain
|
109 |
+
params:
|
110 |
+
specs_dataset_cfg:
|
111 |
+
validation:
|
112 |
+
target: ldm.data.joinaudiodataset_struct_sample_anylen.JoinSpecsValidation
|
113 |
+
params:
|
114 |
+
specs_dataset_cfg:
|
115 |
+
|
116 |
+
test_dataset:
|
117 |
+
target: ldm.data.tsvdataset.TSVDatasetStruct
|
118 |
+
params:
|
119 |
+
tsv_path: musiccap.tsv
|
120 |
+
spec_crop_len: 624
|
121 |
+
|
infer.sh
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CUDA_VISIBLE_DEVICES='1' python scripts/txt2audio_for_lcm.py --n_samples 1 --n_iter 1 --scale 5 --H 20 --W 312 \
|
2 |
+
--ddim_steps 2 -b configs/audiolcm.yaml \
|
3 |
+
--sample_rate 16000 --vocoder-ckpt ../vocoder/logs/bigvnat16k93.5w \
|
4 |
+
--outdir results/test --test-dataset audiocaps -r ../logs/2024-04-21T14-50-11_text2music-audioset-nonoverlap/epoch=000184.ckpt
|
infer_api.sh
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CUDA_VISIBLE_DEVICES='1' python scripts/txt2audio_for_lcm.py --n_samples 1 --n_iter 1 --scale 5 --H 20 --W 312 \
|
2 |
+
--ddim_steps 2 -b configs/audiolcm.yaml \
|
3 |
+
--sample_rate 16000 --vocoder-ckpt ../vocoder/logs/bigvnat16k93.5w \
|
4 |
+
--outdir results/test -r ../logs/2024-04-21T14-50-11_text2music-audioset-nonoverlap/epoch=000184.ckpt --prompt_txt ./prompt.txt
|
ldm/__pycache__/lr_scheduler.cpython-37.pyc
ADDED
Binary file (3.66 kB). View file
|
|
ldm/__pycache__/lr_scheduler.cpython-38.pyc
ADDED
Binary file (3.61 kB). View file
|
|
ldm/__pycache__/util.cpython-310.pyc
ADDED
Binary file (8.36 kB). View file
|
|
ldm/__pycache__/util.cpython-37.pyc
ADDED
Binary file (5.1 kB). View file
|
|
ldm/__pycache__/util.cpython-38.pyc
ADDED
Binary file (8.3 kB). View file
|
|
ldm/data/__pycache__/joinaudiodataset_624.cpython-38.pyc
ADDED
Binary file (3.62 kB). View file
|
|
ldm/data/__pycache__/joinaudiodataset_anylen.cpython-37.pyc
ADDED
Binary file (12.4 kB). View file
|
|
ldm/data/__pycache__/joinaudiodataset_anylen.cpython-38.pyc
ADDED
Binary file (12.1 kB). View file
|
|
ldm/data/__pycache__/joinaudiodataset_struct.cpython-38.pyc
ADDED
Binary file (3.69 kB). View file
|
|
ldm/data/__pycache__/joinaudiodataset_struct_anylen.cpython-38.pyc
ADDED
Binary file (12.5 kB). View file
|
|
ldm/data/__pycache__/joinaudiodataset_struct_sample_anylen.cpython-37.pyc
ADDED
Binary file (8.29 kB). View file
|
|
ldm/data/__pycache__/joinaudiodataset_struct_sample_anylen.cpython-38.pyc
ADDED
Binary file (8.09 kB). View file
|
|
ldm/data/__pycache__/tsvdataset.cpython-38.pyc
ADDED
Binary file (2.66 kB). View file
|
|
ldm/data/joinaudiodataset_624.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import logging
|
5 |
+
import pandas as pd
|
6 |
+
import glob
|
7 |
+
logger = logging.getLogger(f'main.{__name__}')
|
8 |
+
|
9 |
+
sys.path.insert(0, '.') # nopep8
|
10 |
+
|
11 |
+
class JoinManifestSpecs(torch.utils.data.Dataset):
|
12 |
+
def __init__(self, split, spec_dir_path, mel_num=None, spec_crop_len=None,drop=0,**kwargs):
|
13 |
+
super().__init__()
|
14 |
+
self.split = split
|
15 |
+
self.batch_max_length = spec_crop_len
|
16 |
+
self.batch_min_length = 50
|
17 |
+
self.mel_num = mel_num
|
18 |
+
self.drop = drop
|
19 |
+
manifest_files = []
|
20 |
+
for dir_path in spec_dir_path.split(','):
|
21 |
+
manifest_files += glob.glob(f'{dir_path}/**/*.tsv',recursive=True)
|
22 |
+
df_list = [pd.read_csv(manifest,sep='\t') for manifest in manifest_files]
|
23 |
+
df = pd.concat(df_list,ignore_index=True)
|
24 |
+
|
25 |
+
if split == 'train':
|
26 |
+
self.dataset = df.iloc[100:]
|
27 |
+
elif split == 'valid' or split == 'val':
|
28 |
+
self.dataset = df.iloc[:100]
|
29 |
+
elif split == 'test':
|
30 |
+
df = self.add_name_num(df)
|
31 |
+
self.dataset = df
|
32 |
+
else:
|
33 |
+
raise ValueError(f'Unknown split {split}')
|
34 |
+
self.dataset.reset_index(inplace=True)
|
35 |
+
print('dataset len:', len(self.dataset))
|
36 |
+
|
37 |
+
def add_name_num(self,df):
|
38 |
+
"""each file may have different caption, we add num to filename to identify each audio-caption pair"""
|
39 |
+
name_count_dict = {}
|
40 |
+
change = []
|
41 |
+
for t in df.itertuples():
|
42 |
+
name = getattr(t,'name')
|
43 |
+
if name in name_count_dict:
|
44 |
+
name_count_dict[name] += 1
|
45 |
+
else:
|
46 |
+
name_count_dict[name] = 0
|
47 |
+
change.append((t[0],name_count_dict[name]))
|
48 |
+
for t in change:
|
49 |
+
df.loc[t[0],'name'] = df.loc[t[0],'name'] + f'_{t[1]}'
|
50 |
+
return df
|
51 |
+
|
52 |
+
def __getitem__(self, idx):
|
53 |
+
data = self.dataset.iloc[idx]
|
54 |
+
item = {}
|
55 |
+
try:
|
56 |
+
spec = np.load(data['mel_path']) # mel spec [80, 624]
|
57 |
+
except:
|
58 |
+
mel_path = data['mel_path']
|
59 |
+
print(f'corrupted:{mel_path}')
|
60 |
+
spec = np.zeros((self.mel_num,self.batch_max_length)).astype(np.float32)
|
61 |
+
|
62 |
+
if spec.shape[1] < self.batch_max_length:
|
63 |
+
# spec = np.pad(spec, ((0, 0), (0, self.batch_max_length - spec.shape[1]))) # [80, 624]
|
64 |
+
spec = np.tile(spec,reps=(self.batch_max_length//spec.shape[1])+1)
|
65 |
+
|
66 |
+
item['image'] = spec[:,:self.batch_max_length]
|
67 |
+
p = np.random.uniform(0,1)
|
68 |
+
if p > self.drop:
|
69 |
+
item["caption"] = data['caption']
|
70 |
+
else:
|
71 |
+
item["caption"] = ""
|
72 |
+
if self.split == 'test':
|
73 |
+
item['f_name'] = data['name']
|
74 |
+
return item
|
75 |
+
|
76 |
+
def __len__(self):
|
77 |
+
return len(self.dataset)
|
78 |
+
|
79 |
+
|
80 |
+
class JoinSpecsTrain(JoinManifestSpecs):
|
81 |
+
def __init__(self, specs_dataset_cfg):
|
82 |
+
super().__init__('train', **specs_dataset_cfg)
|
83 |
+
|
84 |
+
class JoinSpecsValidation(JoinManifestSpecs):
|
85 |
+
def __init__(self, specs_dataset_cfg):
|
86 |
+
super().__init__('valid', **specs_dataset_cfg)
|
87 |
+
|
88 |
+
class JoinSpecsTest(JoinManifestSpecs):
|
89 |
+
def __init__(self, specs_dataset_cfg):
|
90 |
+
super().__init__('test', **specs_dataset_cfg)
|
91 |
+
|
92 |
+
|
93 |
+
|
ldm/data/joinaudiodataset_anylen.py
ADDED
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import math
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from torch.utils.data.sampler import Sampler
|
7 |
+
from torch.utils.data.distributed import DistributedSampler
|
8 |
+
import torch.distributed
|
9 |
+
from typing import TypeVar, Optional, Iterator,List
|
10 |
+
import logging
|
11 |
+
import pandas as pd
|
12 |
+
import glob
|
13 |
+
import torch.distributed as dist
|
14 |
+
logger = logging.getLogger(f'main.{__name__}')
|
15 |
+
|
16 |
+
sys.path.insert(0, '.') # nopep8
|
17 |
+
|
18 |
+
class JoinManifestSpecs(torch.utils.data.Dataset):
|
19 |
+
def __init__(self, split, spec_dir_path, mel_num=80,spec_crop_len=1248,mode='pad',pad_value=-5,drop=0,**kwargs):
|
20 |
+
super().__init__()
|
21 |
+
self.split = split
|
22 |
+
self.max_batch_len = spec_crop_len
|
23 |
+
self.min_batch_len = 64
|
24 |
+
self.mel_num = mel_num
|
25 |
+
self.min_factor = 4
|
26 |
+
self.drop = drop
|
27 |
+
self.pad_value = pad_value
|
28 |
+
assert mode in ['pad','tile']
|
29 |
+
self.collate_mode = mode
|
30 |
+
# print(f"################# self.collate_mode {self.collate_mode} ##################")
|
31 |
+
|
32 |
+
manifest_files = []
|
33 |
+
for dir_path in spec_dir_path.split(','):
|
34 |
+
manifest_files += glob.glob(f'{dir_path}/*.tsv')
|
35 |
+
df_list = [pd.read_csv(manifest,sep='\t') for manifest in manifest_files]
|
36 |
+
df = pd.concat(df_list,ignore_index=True)
|
37 |
+
|
38 |
+
if split == 'train':
|
39 |
+
self.dataset = df.iloc[100:]
|
40 |
+
elif split == 'valid' or split == 'val':
|
41 |
+
self.dataset = df.iloc[:100]
|
42 |
+
elif split == 'test':
|
43 |
+
df = self.add_name_num(df)
|
44 |
+
self.dataset = df
|
45 |
+
else:
|
46 |
+
raise ValueError(f'Unknown split {split}')
|
47 |
+
self.dataset.reset_index(inplace=True)
|
48 |
+
print('dataset len:', len(self.dataset))
|
49 |
+
|
50 |
+
def add_name_num(self,df):
|
51 |
+
"""each file may have different caption, we add num to filename to identify each audio-caption pair"""
|
52 |
+
name_count_dict = {}
|
53 |
+
change = []
|
54 |
+
for t in df.itertuples():
|
55 |
+
name = getattr(t,'name')
|
56 |
+
if name in name_count_dict:
|
57 |
+
name_count_dict[name] += 1
|
58 |
+
else:
|
59 |
+
name_count_dict[name] = 0
|
60 |
+
change.append((t[0],name_count_dict[name]))
|
61 |
+
for t in change:
|
62 |
+
df.loc[t[0],'name'] = df.loc[t[0],'name'] + f'_{t[1]}'
|
63 |
+
return df
|
64 |
+
|
65 |
+
def ordered_indices(self):
|
66 |
+
index2dur = self.dataset[['duration']]
|
67 |
+
index2dur = index2dur.sort_values(by='duration')
|
68 |
+
return list(index2dur.index)
|
69 |
+
|
70 |
+
def __getitem__(self, idx):
|
71 |
+
item = {}
|
72 |
+
data = self.dataset.iloc[idx]
|
73 |
+
try:
|
74 |
+
spec = np.load(data['mel_path']) # mel spec [80, 624]
|
75 |
+
except:
|
76 |
+
mel_path = data['mel_path']
|
77 |
+
print(f'corrupted:{mel_path}')
|
78 |
+
spec = np.ones((self.mel_num,self.min_batch_len)).astype(np.float32)*self.pad_value
|
79 |
+
|
80 |
+
|
81 |
+
item['image'] = spec
|
82 |
+
p = np.random.uniform(0,1)
|
83 |
+
if p > self.drop:
|
84 |
+
item["caption"] = data['caption']
|
85 |
+
else:
|
86 |
+
item["caption"] = ""
|
87 |
+
if self.split == 'test':
|
88 |
+
item['f_name'] = data['name']
|
89 |
+
# item['f_name'] = data['mel_path']
|
90 |
+
return item
|
91 |
+
|
92 |
+
def collater(self,inputs):
|
93 |
+
to_dict = {}
|
94 |
+
for l in inputs:
|
95 |
+
for k,v in l.items():
|
96 |
+
if k in to_dict:
|
97 |
+
to_dict[k].append(v)
|
98 |
+
else:
|
99 |
+
to_dict[k] = [v]
|
100 |
+
if self.collate_mode == 'pad':
|
101 |
+
to_dict['image'] = collate_1d_or_2d(to_dict['image'],pad_idx=self.pad_value,min_len = self.min_batch_len,max_len=self.max_batch_len,min_factor=self.min_factor)
|
102 |
+
elif self.collate_mode == 'tile':
|
103 |
+
to_dict['image'] = collate_1d_or_2d_tile(to_dict['image'],min_len = self.min_batch_len,max_len=self.max_batch_len,min_factor=self.min_factor)
|
104 |
+
else:
|
105 |
+
raise NotImplementedError
|
106 |
+
|
107 |
+
return to_dict
|
108 |
+
|
109 |
+
def __len__(self):
|
110 |
+
return len(self.dataset)
|
111 |
+
|
112 |
+
|
113 |
+
class JoinSpecsTrain(JoinManifestSpecs):
|
114 |
+
def __init__(self, specs_dataset_cfg):
|
115 |
+
super().__init__('train', **specs_dataset_cfg)
|
116 |
+
|
117 |
+
class JoinSpecsValidation(JoinManifestSpecs):
|
118 |
+
def __init__(self, specs_dataset_cfg):
|
119 |
+
super().__init__('valid', **specs_dataset_cfg)
|
120 |
+
|
121 |
+
class JoinSpecsTest(JoinManifestSpecs):
|
122 |
+
def __init__(self, specs_dataset_cfg):
|
123 |
+
super().__init__('test', **specs_dataset_cfg)
|
124 |
+
|
125 |
+
class JoinSpecsDebug(JoinManifestSpecs):
|
126 |
+
def __init__(self, specs_dataset_cfg):
|
127 |
+
super().__init__('valid', **specs_dataset_cfg)
|
128 |
+
self.dataset = self.dataset.iloc[:37]
|
129 |
+
|
130 |
+
class DDPIndexBatchSampler(Sampler):# 让长度相似的音频的indices合到一个batch中以避免过长的pad
|
131 |
+
def __init__(self, indices ,batch_size, num_replicas: Optional[int] = None,
|
132 |
+
rank: Optional[int] = None, shuffle: bool = True,
|
133 |
+
seed: int = 0, drop_last: bool = False) -> None:
|
134 |
+
if num_replicas is None:
|
135 |
+
if not dist.is_initialized():
|
136 |
+
# raise RuntimeError("Requires distributed package to be available")
|
137 |
+
print("Not in distributed mode")
|
138 |
+
num_replicas = 1
|
139 |
+
else:
|
140 |
+
num_replicas = dist.get_world_size()
|
141 |
+
if rank is None:
|
142 |
+
if not dist.is_initialized():
|
143 |
+
# raise RuntimeError("Requires distributed package to be available")
|
144 |
+
rank = 0
|
145 |
+
else:
|
146 |
+
rank = dist.get_rank()
|
147 |
+
if rank >= num_replicas or rank < 0:
|
148 |
+
raise ValueError(
|
149 |
+
"Invalid rank {}, rank should be in the interval"
|
150 |
+
" [0, {}]".format(rank, num_replicas - 1))
|
151 |
+
self.indices = indices
|
152 |
+
self.num_replicas = num_replicas
|
153 |
+
self.rank = rank
|
154 |
+
self.epoch = 0
|
155 |
+
self.drop_last = drop_last
|
156 |
+
self.batch_size = batch_size
|
157 |
+
|
158 |
+
self.batches = self.build_batches()
|
159 |
+
print(f"rank: {self.rank}, batches_num {len(self.batches)}")
|
160 |
+
# If the dataset length is evenly divisible by replicas, then there
|
161 |
+
# is no need to drop any data, since the dataset will be split equally.
|
162 |
+
if self.drop_last and len(self.batches) % self.num_replicas != 0:
|
163 |
+
self.batches = self.batches[:len(self.batches)//self.num_replicas*self.num_replicas]
|
164 |
+
if len(self.batches) > self.num_replicas:
|
165 |
+
self.batches = self.batches[self.rank::self.num_replicas]
|
166 |
+
else: # may happen in sanity checking
|
167 |
+
self.batches = [self.batches[0]]
|
168 |
+
print(f"after split batches_num {len(self.batches)}")
|
169 |
+
self.shuffle = shuffle
|
170 |
+
if self.shuffle:
|
171 |
+
self.batches = np.random.permutation(self.batches)
|
172 |
+
self.seed = seed
|
173 |
+
|
174 |
+
def set_epoch(self,epoch):
|
175 |
+
self.epoch = epoch
|
176 |
+
if self.shuffle:
|
177 |
+
np.random.seed(self.seed+self.epoch)
|
178 |
+
self.batches = np.random.permutation(self.batches)
|
179 |
+
|
180 |
+
def build_batches(self):
|
181 |
+
batches,batch = [],[]
|
182 |
+
for index in self.indices:
|
183 |
+
batch.append(index)
|
184 |
+
if len(batch) == self.batch_size:
|
185 |
+
batches.append(batch)
|
186 |
+
batch = []
|
187 |
+
if not self.drop_last and len(batch) > 0:
|
188 |
+
batches.append(batch)
|
189 |
+
return batches
|
190 |
+
|
191 |
+
def __iter__(self) -> Iterator[List[int]]:
|
192 |
+
for batch in self.batches:
|
193 |
+
yield batch
|
194 |
+
|
195 |
+
def __len__(self) -> int:
|
196 |
+
return len(self.batches)
|
197 |
+
|
198 |
+
def set_epoch(self, epoch: int) -> None:
|
199 |
+
r"""
|
200 |
+
Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
|
201 |
+
use a different random ordering for each epoch. Otherwise, the next iteration of this
|
202 |
+
sampler will yield the same ordering.
|
203 |
+
|
204 |
+
Args:
|
205 |
+
epoch (int): Epoch number.
|
206 |
+
"""
|
207 |
+
self.epoch = epoch
|
208 |
+
|
209 |
+
|
210 |
+
def collate_1d_or_2d(values, pad_idx=0, left_pad=False, shift_right=False,min_len = None, max_len=None,min_factor=None, shift_id=1):
|
211 |
+
if len(values[0].shape) == 1:
|
212 |
+
return collate_1d(values, pad_idx, left_pad, shift_right,min_len, max_len,min_factor, shift_id)
|
213 |
+
else:
|
214 |
+
return collate_2d(values, pad_idx, left_pad, shift_right,min_len,max_len,min_factor)
|
215 |
+
|
216 |
+
def collate_1d(values, pad_idx=0, left_pad=False, shift_right=False,min_len=None, max_len=None,min_factor=None, shift_id=1):
|
217 |
+
"""Convert a list of 1d tensors into a padded 2d tensor."""
|
218 |
+
size = max(v.size(0) for v in values)
|
219 |
+
if max_len:
|
220 |
+
size = min(size,max_len)
|
221 |
+
if min_len:
|
222 |
+
size = max(size,min_len)
|
223 |
+
if min_factor and (size % min_factor!=0):# size must be the multiple of min_factor
|
224 |
+
size += (min_factor - size % min_factor)
|
225 |
+
res = values[0].new(len(values), size).fill_(pad_idx)
|
226 |
+
|
227 |
+
def copy_tensor(src, dst):
|
228 |
+
assert dst.numel() == src.numel(), f"dst shape:{dst.shape} src shape:{src.shape}"
|
229 |
+
if shift_right:
|
230 |
+
dst[1:] = src[:-1]
|
231 |
+
dst[0] = shift_id
|
232 |
+
else:
|
233 |
+
dst.copy_(src)
|
234 |
+
|
235 |
+
for i, v in enumerate(values):
|
236 |
+
copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
|
237 |
+
return res
|
238 |
+
|
239 |
+
|
240 |
+
def collate_2d(values, pad_idx=0, left_pad=False, shift_right=False, min_len=None,max_len=None,min_factor=None):
|
241 |
+
"""Collate 2d for melspec,Convert a list of 2d tensors into a padded 3d tensor,pad in mel_length dimension.
|
242 |
+
values[0] shape: (melbins,mel_length)
|
243 |
+
"""
|
244 |
+
size = max(v.shape[1] for v in values) # if max_len is None else max_len
|
245 |
+
if max_len:
|
246 |
+
size = min(size,max_len)
|
247 |
+
if min_len:
|
248 |
+
size = max(size,min_len)
|
249 |
+
if min_factor and (size % min_factor!=0):# size must be the multiple of min_factor
|
250 |
+
size += (min_factor - size % min_factor)
|
251 |
+
|
252 |
+
if isinstance(values,np.ndarray):
|
253 |
+
values = torch.FloatTensor(values)
|
254 |
+
if isinstance(values,list):
|
255 |
+
values = [torch.FloatTensor(v) for v in values]
|
256 |
+
res = torch.ones(len(values), values[0].shape[0],size).to(dtype=torch.float32)*pad_idx
|
257 |
+
|
258 |
+
def copy_tensor(src, dst):
|
259 |
+
assert dst.numel() == src.numel(), f"dst shape:{dst.shape} src shape:{src.shape}"
|
260 |
+
if shift_right:
|
261 |
+
dst[1:] = src[:-1]
|
262 |
+
else:
|
263 |
+
dst.copy_(src)
|
264 |
+
|
265 |
+
for i, v in enumerate(values):
|
266 |
+
copy_tensor(v[:,:size], res[i][:,size - v.shape[1]:] if left_pad else res[i][:,:v.shape[1]])
|
267 |
+
return res
|
268 |
+
|
269 |
+
|
270 |
+
def collate_1d_or_2d_tile(values, shift_right=False,min_len = None, max_len=None,min_factor=None, shift_id=1):
|
271 |
+
if len(values[0].shape) == 1:
|
272 |
+
return collate_1d_tile(values, shift_right,min_len, max_len,min_factor, shift_id)
|
273 |
+
else:
|
274 |
+
return collate_2d_tile(values, shift_right,min_len,max_len,min_factor)
|
275 |
+
|
276 |
+
def collate_1d_tile(values, shift_right=False,min_len=None, max_len=None,min_factor=None,shift_id=1):
|
277 |
+
"""Convert a list of 1d tensors into a padded 2d tensor."""
|
278 |
+
size = max(v.size(0) for v in values)
|
279 |
+
if max_len:
|
280 |
+
size = min(size,max_len)
|
281 |
+
if min_len:
|
282 |
+
size = max(size,min_len)
|
283 |
+
if min_factor and (size%min_factor!=0):# size must be the multiple of min_factor
|
284 |
+
size += (min_factor - size % min_factor)
|
285 |
+
res = values[0].new(len(values), size)
|
286 |
+
|
287 |
+
def copy_tensor(src, dst):
|
288 |
+
assert dst.numel() == src.numel(), f"dst shape:{dst.shape} src shape:{src.shape}"
|
289 |
+
if shift_right:
|
290 |
+
dst[1:] = src[:-1]
|
291 |
+
dst[0] = shift_id
|
292 |
+
else:
|
293 |
+
dst.copy_(src)
|
294 |
+
|
295 |
+
for i, v in enumerate(values):
|
296 |
+
n_repeat = math.ceil((size + 1) / v.shape[0])
|
297 |
+
v = torch.tile(v,dims=(1,n_repeat))[:size]
|
298 |
+
copy_tensor(v, res[i])
|
299 |
+
|
300 |
+
return res
|
301 |
+
|
302 |
+
|
303 |
+
def collate_2d_tile(values, shift_right=False, min_len=None,max_len=None,min_factor=None):
|
304 |
+
"""Collate 2d for melspec,Convert a list of 2d tensors into a padded 3d tensor,pad in mel_length dimension. """
|
305 |
+
size = max(v.shape[1] for v in values) # if max_len is None else max_len
|
306 |
+
if max_len:
|
307 |
+
size = min(size,max_len)
|
308 |
+
if min_len:
|
309 |
+
size = max(size,min_len)
|
310 |
+
if min_factor and (size % min_factor!=0):# size must be the multiple of min_factor
|
311 |
+
size += (min_factor - size % min_factor)
|
312 |
+
|
313 |
+
if isinstance(values,np.ndarray):
|
314 |
+
values = torch.FloatTensor(values)
|
315 |
+
if isinstance(values,list):
|
316 |
+
values = [torch.FloatTensor(v) for v in values]
|
317 |
+
res = torch.zeros(len(values), values[0].shape[0],size).to(dtype=torch.float32)
|
318 |
+
|
319 |
+
def copy_tensor(src, dst):
|
320 |
+
assert dst.numel() == src.numel()
|
321 |
+
if shift_right:
|
322 |
+
dst[1:] = src[:-1]
|
323 |
+
else:
|
324 |
+
dst.copy_(src)
|
325 |
+
|
326 |
+
for i, v in enumerate(values):
|
327 |
+
n_repeat = math.ceil((size + 1) / v.shape[1])
|
328 |
+
v = torch.tile(v,dims=(1,n_repeat))[:,:size]
|
329 |
+
copy_tensor(v, res[i])
|
330 |
+
|
331 |
+
return res
|
ldm/data/joinaudiodataset_struct.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import logging
|
5 |
+
import pandas as pd
|
6 |
+
import glob
|
7 |
+
logger = logging.getLogger(f'main.{__name__}')
|
8 |
+
|
9 |
+
sys.path.insert(0, '.') # nopep8
|
10 |
+
|
11 |
+
class JoinManifestSpecs(torch.utils.data.Dataset):
|
12 |
+
def __init__(self, split, spec_dir_path, mel_num=None, spec_crop_len=None,drop=0,**kwargs):
|
13 |
+
super().__init__()
|
14 |
+
self.split = split
|
15 |
+
self.batch_max_length = spec_crop_len
|
16 |
+
self.batch_min_length = 50
|
17 |
+
self.drop = drop
|
18 |
+
self.mel_num = mel_num
|
19 |
+
|
20 |
+
manifest_files = []
|
21 |
+
for dir_path in spec_dir_path.split(','):
|
22 |
+
manifest_files += glob.glob(f'{dir_path}/*.tsv')
|
23 |
+
df_list = [pd.read_csv(manifest,sep='\t') for manifest in manifest_files]
|
24 |
+
df = pd.concat(df_list,ignore_index=True)
|
25 |
+
|
26 |
+
if split == 'train':
|
27 |
+
self.dataset = df.iloc[100:]
|
28 |
+
elif split == 'valid' or split == 'val':
|
29 |
+
self.dataset = df.iloc[:100]
|
30 |
+
elif split == 'test':
|
31 |
+
df = self.add_name_num(df)
|
32 |
+
self.dataset = df
|
33 |
+
else:
|
34 |
+
raise ValueError(f'Unknown split {split}')
|
35 |
+
self.dataset.reset_index(inplace=True)
|
36 |
+
print('dataset len:', len(self.dataset))
|
37 |
+
|
38 |
+
def add_name_num(self,df):
|
39 |
+
"""each file may have different caption, we add num to filename to identify each audio-caption pair"""
|
40 |
+
name_count_dict = {}
|
41 |
+
change = []
|
42 |
+
for t in df.itertuples():
|
43 |
+
name = getattr(t,'name')
|
44 |
+
if name in name_count_dict:
|
45 |
+
name_count_dict[name] += 1
|
46 |
+
else:
|
47 |
+
name_count_dict[name] = 0
|
48 |
+
change.append((t[0],name_count_dict[name]))
|
49 |
+
for t in change:
|
50 |
+
df.loc[t[0],'name'] = df.loc[t[0],'name'] + f'_{t[1]}'
|
51 |
+
return df
|
52 |
+
|
53 |
+
def __getitem__(self, idx):
|
54 |
+
data = self.dataset.iloc[idx]
|
55 |
+
item = {}
|
56 |
+
try:
|
57 |
+
spec = np.load(data['mel_path']) # mel spec [80, 624]
|
58 |
+
except:
|
59 |
+
mel_path = data['mel_path']
|
60 |
+
print(f'corrupted:{mel_path}')
|
61 |
+
spec = np.zeros((self.mel_num,self.batch_max_length)).astype(np.float32)
|
62 |
+
|
63 |
+
if spec.shape[1] <= self.batch_max_length:
|
64 |
+
spec = np.pad(spec, ((0, 0), (0, self.batch_max_length - spec.shape[1]))) # [80, 624]
|
65 |
+
|
66 |
+
|
67 |
+
item['image'] = spec[:self.mel_num,:self.batch_max_length]
|
68 |
+
p = np.random.uniform(0,1)
|
69 |
+
if p > self.drop:
|
70 |
+
item["caption"] = {"ori_caption":data['ori_cap'],"struct_caption":data['caption']}
|
71 |
+
else:
|
72 |
+
item["caption"] = {"ori_caption":"","struct_caption":""}
|
73 |
+
|
74 |
+
if self.split == 'test':
|
75 |
+
item['f_name'] = data['name']
|
76 |
+
return item
|
77 |
+
|
78 |
+
def __len__(self):
|
79 |
+
return len(self.dataset)
|
80 |
+
|
81 |
+
|
82 |
+
class JoinSpecsTrain(JoinManifestSpecs):
|
83 |
+
def __init__(self, specs_dataset_cfg):
|
84 |
+
super().__init__('train', **specs_dataset_cfg)
|
85 |
+
|
86 |
+
class JoinSpecsValidation(JoinManifestSpecs):
|
87 |
+
def __init__(self, specs_dataset_cfg):
|
88 |
+
super().__init__('valid', **specs_dataset_cfg)
|
89 |
+
|
90 |
+
class JoinSpecsTest(JoinManifestSpecs):
|
91 |
+
def __init__(self, specs_dataset_cfg):
|
92 |
+
super().__init__('test', **specs_dataset_cfg)
|
93 |
+
|
94 |
+
|
95 |
+
|
ldm/data/joinaudiodataset_struct_anylen.py
ADDED
@@ -0,0 +1,336 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import math
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from torch.utils.data.sampler import Sampler
|
7 |
+
from torch.utils.data.distributed import DistributedSampler
|
8 |
+
import torch.distributed
|
9 |
+
from typing import TypeVar, Optional, Iterator,List
|
10 |
+
import logging
|
11 |
+
import pandas as pd
|
12 |
+
import glob
|
13 |
+
import torch.distributed as dist
|
14 |
+
logger = logging.getLogger(f'main.{__name__}')
|
15 |
+
|
16 |
+
sys.path.insert(0, '.') # nopep8
|
17 |
+
|
18 |
+
class JoinManifestSpecs(torch.utils.data.Dataset):
|
19 |
+
def __init__(self, split, spec_dir_path, mel_num=80,spec_crop_len=1248,mode='pad',pad_value=-5,drop=0,**kwargs):
|
20 |
+
super().__init__()
|
21 |
+
self.split = split
|
22 |
+
self.max_batch_len = spec_crop_len
|
23 |
+
self.min_batch_len = 64
|
24 |
+
self.mel_num = mel_num
|
25 |
+
self.min_factor = 4
|
26 |
+
self.drop = drop
|
27 |
+
self.pad_value = pad_value
|
28 |
+
assert mode in ['pad','tile']
|
29 |
+
self.collate_mode = mode
|
30 |
+
# print(f"################# self.collate_mode {self.collate_mode} ##################")
|
31 |
+
|
32 |
+
manifest_files = []
|
33 |
+
for dir_path in spec_dir_path.split(','):
|
34 |
+
manifest_files += glob.glob(f'{dir_path}/*.tsv')
|
35 |
+
df_list = [pd.read_csv(manifest,sep='\t') for manifest in manifest_files]
|
36 |
+
df = pd.concat(df_list,ignore_index=True)
|
37 |
+
|
38 |
+
if split == 'train':
|
39 |
+
self.dataset = df.iloc[100:]
|
40 |
+
elif split == 'valid' or split == 'val':
|
41 |
+
self.dataset = df.iloc[:100]
|
42 |
+
elif split == 'test':
|
43 |
+
df = self.add_name_num(df)
|
44 |
+
self.dataset = df
|
45 |
+
else:
|
46 |
+
raise ValueError(f'Unknown split {split}')
|
47 |
+
self.dataset.reset_index(inplace=True)
|
48 |
+
print('dataset len:', len(self.dataset))
|
49 |
+
|
50 |
+
def add_name_num(self,df):
|
51 |
+
"""each file may have different caption, we add num to filename to identify each audio-caption pair"""
|
52 |
+
name_count_dict = {}
|
53 |
+
change = []
|
54 |
+
for t in df.itertuples():
|
55 |
+
name = getattr(t,'name')
|
56 |
+
if name in name_count_dict:
|
57 |
+
name_count_dict[name] += 1
|
58 |
+
else:
|
59 |
+
name_count_dict[name] = 0
|
60 |
+
change.append((t[0],name_count_dict[name]))
|
61 |
+
for t in change:
|
62 |
+
df.loc[t[0],'name'] = df.loc[t[0],'name'] + f'_{t[1]}'
|
63 |
+
return df
|
64 |
+
|
65 |
+
def ordered_indices(self):
|
66 |
+
index2dur = self.dataset[['duration']]
|
67 |
+
index2dur = index2dur.sort_values(by='duration')
|
68 |
+
return list(index2dur.index)
|
69 |
+
|
70 |
+
def __getitem__(self, idx):
|
71 |
+
item = {}
|
72 |
+
data = self.dataset.iloc[idx]
|
73 |
+
try:
|
74 |
+
spec = np.load(data['mel_path']) # mel spec [80, 624]
|
75 |
+
except:
|
76 |
+
mel_path = data['mel_path']
|
77 |
+
print(f'corrupted:{mel_path}')
|
78 |
+
spec = np.ones((self.mel_num,self.min_batch_len)).astype(np.float32)*self.pad_value
|
79 |
+
|
80 |
+
|
81 |
+
item['image'] = spec
|
82 |
+
p = np.random.uniform(0,1)
|
83 |
+
if p > self.drop:
|
84 |
+
ori_caption = data['caption']
|
85 |
+
struct_caption = f'<{ori_caption}& all>'
|
86 |
+
else:
|
87 |
+
ori_caption = ""
|
88 |
+
struct_caption = ""
|
89 |
+
item["caption"] = {"ori_caption":ori_caption,"struct_caption":struct_caption}
|
90 |
+
if self.split == 'test':
|
91 |
+
item['f_name'] = data['name']
|
92 |
+
# item['f_name'] = data['mel_path']
|
93 |
+
return item
|
94 |
+
|
95 |
+
def collater(self,inputs):
|
96 |
+
to_dict = {}
|
97 |
+
for l in inputs:
|
98 |
+
for k,v in l.items():
|
99 |
+
if k in to_dict:
|
100 |
+
to_dict[k].append(v)
|
101 |
+
else:
|
102 |
+
to_dict[k] = [v]
|
103 |
+
if self.collate_mode == 'pad':
|
104 |
+
to_dict['image'] = collate_1d_or_2d(to_dict['image'],pad_idx=self.pad_value,min_len = self.min_batch_len,max_len=self.max_batch_len,min_factor=self.min_factor)
|
105 |
+
elif self.collate_mode == 'tile':
|
106 |
+
to_dict['image'] = collate_1d_or_2d_tile(to_dict['image'],min_len = self.min_batch_len,max_len=self.max_batch_len,min_factor=self.min_factor)
|
107 |
+
else:
|
108 |
+
raise NotImplementedError
|
109 |
+
to_dict['caption'] = {'ori_caption':[c['ori_caption'] for c in to_dict['caption']],
|
110 |
+
'struct_caption':[c['struct_caption'] for c in to_dict['caption']]}
|
111 |
+
|
112 |
+
return to_dict
|
113 |
+
|
114 |
+
def __len__(self):
|
115 |
+
return len(self.dataset)
|
116 |
+
|
117 |
+
|
118 |
+
class JoinSpecsTrain(JoinManifestSpecs):
|
119 |
+
def __init__(self, specs_dataset_cfg):
|
120 |
+
super().__init__('train', **specs_dataset_cfg)
|
121 |
+
|
122 |
+
class JoinSpecsValidation(JoinManifestSpecs):
|
123 |
+
def __init__(self, specs_dataset_cfg):
|
124 |
+
super().__init__('valid', **specs_dataset_cfg)
|
125 |
+
|
126 |
+
class JoinSpecsTest(JoinManifestSpecs):
|
127 |
+
def __init__(self, specs_dataset_cfg):
|
128 |
+
super().__init__('test', **specs_dataset_cfg)
|
129 |
+
|
130 |
+
class JoinSpecsDebug(JoinManifestSpecs):
|
131 |
+
def __init__(self, specs_dataset_cfg):
|
132 |
+
super().__init__('valid', **specs_dataset_cfg)
|
133 |
+
self.dataset = self.dataset.iloc[:37]
|
134 |
+
|
135 |
+
class DDPIndexBatchSampler(Sampler):# 让长度相似的音频的indices合到��个batch中以避免过长的pad
|
136 |
+
def __init__(self, indices ,batch_size, num_replicas: Optional[int] = None,
|
137 |
+
rank: Optional[int] = None, shuffle: bool = True,
|
138 |
+
seed: int = 0, drop_last: bool = False) -> None:
|
139 |
+
if num_replicas is None:
|
140 |
+
if not dist.is_initialized():
|
141 |
+
# raise RuntimeError("Requires distributed package to be available")
|
142 |
+
print("Not in distributed mode")
|
143 |
+
num_replicas = 1
|
144 |
+
else:
|
145 |
+
num_replicas = dist.get_world_size()
|
146 |
+
if rank is None:
|
147 |
+
if not dist.is_initialized():
|
148 |
+
# raise RuntimeError("Requires distributed package to be available")
|
149 |
+
rank = 0
|
150 |
+
else:
|
151 |
+
rank = dist.get_rank()
|
152 |
+
if rank >= num_replicas or rank < 0:
|
153 |
+
raise ValueError(
|
154 |
+
"Invalid rank {}, rank should be in the interval"
|
155 |
+
" [0, {}]".format(rank, num_replicas - 1))
|
156 |
+
self.indices = indices
|
157 |
+
self.num_replicas = num_replicas
|
158 |
+
self.rank = rank
|
159 |
+
self.epoch = 0
|
160 |
+
self.drop_last = drop_last
|
161 |
+
self.batch_size = batch_size
|
162 |
+
|
163 |
+
self.batches = self.build_batches()
|
164 |
+
print(f"rank: {self.rank}, batches_num {len(self.batches)}")
|
165 |
+
# If the dataset length is evenly divisible by replicas, then there
|
166 |
+
# is no need to drop any data, since the dataset will be split equally.
|
167 |
+
if self.drop_last and len(self.batches) % self.num_replicas != 0:
|
168 |
+
self.batches = self.batches[:len(self.batches)//self.num_replicas*self.num_replicas]
|
169 |
+
if len(self.batches) > self.num_replicas:
|
170 |
+
self.batches = self.batches[self.rank::self.num_replicas]
|
171 |
+
else: # may happen in sanity checking
|
172 |
+
self.batches = [self.batches[0]]
|
173 |
+
print(f"after split batches_num {len(self.batches)}")
|
174 |
+
self.shuffle = shuffle
|
175 |
+
if self.shuffle:
|
176 |
+
self.batches = np.random.permutation(self.batches)
|
177 |
+
self.seed = seed
|
178 |
+
|
179 |
+
def set_epoch(self,epoch):
|
180 |
+
self.epoch = epoch
|
181 |
+
if self.shuffle:
|
182 |
+
np.random.seed(self.seed+self.epoch)
|
183 |
+
self.batches = np.random.permutation(self.batches)
|
184 |
+
|
185 |
+
def build_batches(self):
|
186 |
+
batches,batch = [],[]
|
187 |
+
for index in self.indices:
|
188 |
+
batch.append(index)
|
189 |
+
if len(batch) == self.batch_size:
|
190 |
+
batches.append(batch)
|
191 |
+
batch = []
|
192 |
+
if not self.drop_last and len(batch) > 0:
|
193 |
+
batches.append(batch)
|
194 |
+
return batches
|
195 |
+
|
196 |
+
def __iter__(self) -> Iterator[List[int]]:
|
197 |
+
for batch in self.batches:
|
198 |
+
yield batch
|
199 |
+
|
200 |
+
def __len__(self) -> int:
|
201 |
+
return len(self.batches)
|
202 |
+
|
203 |
+
def set_epoch(self, epoch: int) -> None:
|
204 |
+
r"""
|
205 |
+
Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
|
206 |
+
use a different random ordering for each epoch. Otherwise, the next iteration of this
|
207 |
+
sampler will yield the same ordering.
|
208 |
+
|
209 |
+
Args:
|
210 |
+
epoch (int): Epoch number.
|
211 |
+
"""
|
212 |
+
self.epoch = epoch
|
213 |
+
|
214 |
+
|
215 |
+
def collate_1d_or_2d(values, pad_idx=0, left_pad=False, shift_right=False,min_len = None, max_len=None,min_factor=None, shift_id=1):
|
216 |
+
if len(values[0].shape) == 1:
|
217 |
+
return collate_1d(values, pad_idx, left_pad, shift_right,min_len, max_len,min_factor, shift_id)
|
218 |
+
else:
|
219 |
+
return collate_2d(values, pad_idx, left_pad, shift_right,min_len,max_len,min_factor)
|
220 |
+
|
221 |
+
def collate_1d(values, pad_idx=0, left_pad=False, shift_right=False,min_len=None, max_len=None,min_factor=None, shift_id=1):
|
222 |
+
"""Convert a list of 1d tensors into a padded 2d tensor."""
|
223 |
+
size = max(v.size(0) for v in values)
|
224 |
+
if max_len:
|
225 |
+
size = min(size,max_len)
|
226 |
+
if min_len:
|
227 |
+
size = max(size,min_len)
|
228 |
+
if min_factor and (size % min_factor!=0):# size must be the multiple of min_factor
|
229 |
+
size += (min_factor - size % min_factor)
|
230 |
+
res = values[0].new(len(values), size).fill_(pad_idx)
|
231 |
+
|
232 |
+
def copy_tensor(src, dst):
|
233 |
+
assert dst.numel() == src.numel(), f"dst shape:{dst.shape} src shape:{src.shape}"
|
234 |
+
if shift_right:
|
235 |
+
dst[1:] = src[:-1]
|
236 |
+
dst[0] = shift_id
|
237 |
+
else:
|
238 |
+
dst.copy_(src)
|
239 |
+
|
240 |
+
for i, v in enumerate(values):
|
241 |
+
copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
|
242 |
+
return res
|
243 |
+
|
244 |
+
|
245 |
+
def collate_2d(values, pad_idx=0, left_pad=False, shift_right=False, min_len=None,max_len=None,min_factor=None):
|
246 |
+
"""Collate 2d for melspec,Convert a list of 2d tensors into a padded 3d tensor,pad in mel_length dimension.
|
247 |
+
values[0] shape: (melbins,mel_length)
|
248 |
+
"""
|
249 |
+
size = max(v.shape[1] for v in values) # if max_len is None else max_len
|
250 |
+
if max_len:
|
251 |
+
size = min(size,max_len)
|
252 |
+
if min_len:
|
253 |
+
size = max(size,min_len)
|
254 |
+
if min_factor and (size % min_factor!=0):# size must be the multiple of min_factor
|
255 |
+
size += (min_factor - size % min_factor)
|
256 |
+
|
257 |
+
if isinstance(values,np.ndarray):
|
258 |
+
values = torch.FloatTensor(values)
|
259 |
+
if isinstance(values,list):
|
260 |
+
values = [torch.FloatTensor(v) for v in values]
|
261 |
+
res = torch.ones(len(values), values[0].shape[0],size).to(dtype=torch.float32)*pad_idx
|
262 |
+
|
263 |
+
def copy_tensor(src, dst):
|
264 |
+
assert dst.numel() == src.numel(), f"dst shape:{dst.shape} src shape:{src.shape}"
|
265 |
+
if shift_right:
|
266 |
+
dst[1:] = src[:-1]
|
267 |
+
else:
|
268 |
+
dst.copy_(src)
|
269 |
+
|
270 |
+
for i, v in enumerate(values):
|
271 |
+
copy_tensor(v[:,:size], res[i][:,size - v.shape[1]:] if left_pad else res[i][:,:v.shape[1]])
|
272 |
+
return res
|
273 |
+
|
274 |
+
|
275 |
+
def collate_1d_or_2d_tile(values, shift_right=False,min_len = None, max_len=None,min_factor=None, shift_id=1):
|
276 |
+
if len(values[0].shape) == 1:
|
277 |
+
return collate_1d_tile(values, shift_right,min_len, max_len,min_factor, shift_id)
|
278 |
+
else:
|
279 |
+
return collate_2d_tile(values, shift_right,min_len,max_len,min_factor)
|
280 |
+
|
281 |
+
def collate_1d_tile(values, shift_right=False,min_len=None, max_len=None,min_factor=None,shift_id=1):
|
282 |
+
"""Convert a list of 1d tensors into a padded 2d tensor."""
|
283 |
+
size = max(v.size(0) for v in values)
|
284 |
+
if max_len:
|
285 |
+
size = min(size,max_len)
|
286 |
+
if min_len:
|
287 |
+
size = max(size,min_len)
|
288 |
+
if min_factor and (size%min_factor!=0):# size must be the multiple of min_factor
|
289 |
+
size += (min_factor - size % min_factor)
|
290 |
+
res = values[0].new(len(values), size)
|
291 |
+
|
292 |
+
def copy_tensor(src, dst):
|
293 |
+
assert dst.numel() == src.numel(), f"dst shape:{dst.shape} src shape:{src.shape}"
|
294 |
+
if shift_right:
|
295 |
+
dst[1:] = src[:-1]
|
296 |
+
dst[0] = shift_id
|
297 |
+
else:
|
298 |
+
dst.copy_(src)
|
299 |
+
|
300 |
+
for i, v in enumerate(values):
|
301 |
+
n_repeat = math.ceil((size + 1) / v.shape[0])
|
302 |
+
v = torch.tile(v,dims=(1,n_repeat))[:size]
|
303 |
+
copy_tensor(v, res[i])
|
304 |
+
|
305 |
+
return res
|
306 |
+
|
307 |
+
|
308 |
+
def collate_2d_tile(values, shift_right=False, min_len=None,max_len=None,min_factor=None):
|
309 |
+
"""Collate 2d for melspec,Convert a list of 2d tensors into a padded 3d tensor,pad in mel_length dimension. """
|
310 |
+
size = max(v.shape[1] for v in values) # if max_len is None else max_len
|
311 |
+
if max_len:
|
312 |
+
size = min(size,max_len)
|
313 |
+
if min_len:
|
314 |
+
size = max(size,min_len)
|
315 |
+
if min_factor and (size % min_factor!=0):# size must be the multiple of min_factor
|
316 |
+
size += (min_factor - size % min_factor)
|
317 |
+
|
318 |
+
if isinstance(values,np.ndarray):
|
319 |
+
values = torch.FloatTensor(values)
|
320 |
+
if isinstance(values,list):
|
321 |
+
values = [torch.FloatTensor(v) for v in values]
|
322 |
+
res = torch.zeros(len(values), values[0].shape[0],size).to(dtype=torch.float32)
|
323 |
+
|
324 |
+
def copy_tensor(src, dst):
|
325 |
+
assert dst.numel() == src.numel()
|
326 |
+
if shift_right:
|
327 |
+
dst[1:] = src[:-1]
|
328 |
+
else:
|
329 |
+
dst.copy_(src)
|
330 |
+
|
331 |
+
for i, v in enumerate(values):
|
332 |
+
n_repeat = math.ceil((size + 1) / v.shape[1])
|
333 |
+
v = torch.tile(v,dims=(1,n_repeat))[:,:size]
|
334 |
+
copy_tensor(v, res[i])
|
335 |
+
|
336 |
+
return res
|
ldm/data/joinaudiodataset_struct_sample.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import logging
|
5 |
+
import pandas as pd
|
6 |
+
import glob
|
7 |
+
logger = logging.getLogger(f'main.{__name__}')
|
8 |
+
|
9 |
+
sys.path.insert(0, '.') # nopep8
|
10 |
+
|
11 |
+
class JoinManifestSpecs(torch.utils.data.Dataset):
|
12 |
+
def __init__(self, split, main_spec_dir_path,other_spec_dir_path, mel_num=None, spec_crop_len=None,pad_value=-5,**kwargs):
|
13 |
+
super().__init__()
|
14 |
+
self.main_prob = 0.5
|
15 |
+
self.split = split
|
16 |
+
self.batch_max_length = spec_crop_len
|
17 |
+
self.batch_min_length = 50
|
18 |
+
self.mel_num = mel_num
|
19 |
+
self.pad_value = pad_value
|
20 |
+
manifest_files = []
|
21 |
+
for dir_path in main_spec_dir_path.split(','):
|
22 |
+
manifest_files += glob.glob(f'{dir_path}/*.tsv')
|
23 |
+
df_list = [pd.read_csv(manifest,sep='\t') for manifest in manifest_files]
|
24 |
+
self.df_main = pd.concat(df_list,ignore_index=True)
|
25 |
+
|
26 |
+
manifest_files = []
|
27 |
+
for dir_path in other_spec_dir_path.split(','):
|
28 |
+
manifest_files += glob.glob(f'{dir_path}/*.tsv')
|
29 |
+
df_list = [pd.read_csv(manifest,sep='\t') for manifest in manifest_files]
|
30 |
+
self.df_other = pd.concat(df_list,ignore_index=True)
|
31 |
+
|
32 |
+
if split == 'train':
|
33 |
+
self.dataset = self.df_main.iloc[100:]
|
34 |
+
elif split == 'valid' or split == 'val':
|
35 |
+
self.dataset = self.df_main.iloc[:100]
|
36 |
+
elif split == 'test':
|
37 |
+
self.df_main = self.add_name_num(self.df_main)
|
38 |
+
self.dataset = self.df_main
|
39 |
+
else:
|
40 |
+
raise ValueError(f'Unknown split {split}')
|
41 |
+
self.dataset.reset_index(inplace=True)
|
42 |
+
print('dataset len:', len(self.dataset))
|
43 |
+
|
44 |
+
def add_name_num(self,df):
|
45 |
+
"""each file may have different caption, we add num to filename to identify each audio-caption pair"""
|
46 |
+
name_count_dict = {}
|
47 |
+
change = []
|
48 |
+
for t in df.itertuples():
|
49 |
+
name = getattr(t,'name')
|
50 |
+
if name in name_count_dict:
|
51 |
+
name_count_dict[name] += 1
|
52 |
+
else:
|
53 |
+
name_count_dict[name] = 0
|
54 |
+
change.append((t[0],name_count_dict[name]))
|
55 |
+
for t in change:
|
56 |
+
df.loc[t[0],'name'] = df.loc[t[0],'name'] + f'_{t[1]}'
|
57 |
+
return df
|
58 |
+
|
59 |
+
def __getitem__(self, idx):
|
60 |
+
if np.random.uniform(0,1) < self.main_prob:
|
61 |
+
data = self.dataset.iloc[idx]
|
62 |
+
ori_caption = data['ori_cap']
|
63 |
+
struct_caption = data['caption']
|
64 |
+
else:
|
65 |
+
randidx = np.random.randint(0,len(self.df_other))
|
66 |
+
data = self.df_other.iloc[randidx]
|
67 |
+
ori_caption = data['caption']
|
68 |
+
struct_caption = f'<{ori_caption}, all>'
|
69 |
+
item = {}
|
70 |
+
try:
|
71 |
+
spec = np.load(data['mel_path']) # mel spec [80, 624]
|
72 |
+
except:
|
73 |
+
mel_path = data['mel_path']
|
74 |
+
print(f'corrupted:{mel_path}')
|
75 |
+
spec = np.ones((self.mel_num,self.batch_max_length)).astype(np.float32)*self.pad_value
|
76 |
+
|
77 |
+
if spec.shape[1] <= self.batch_max_length:
|
78 |
+
spec = np.pad(spec, ((0, 0), (0, self.batch_max_length - spec.shape[1])),mode='constant',constant_values = (self.pad_value,self.pad_value)) # [80, 624]
|
79 |
+
|
80 |
+
item['image'] = spec[:self.mel_num,:self.batch_max_length]
|
81 |
+
item["caption"] = {"ori_caption":ori_caption,"struct_caption":struct_caption}
|
82 |
+
if self.split == 'test':
|
83 |
+
item['f_name'] = data['name']
|
84 |
+
return item
|
85 |
+
|
86 |
+
def __len__(self):
|
87 |
+
return len(self.dataset)
|
88 |
+
|
89 |
+
|
90 |
+
class JoinSpecsTrain(JoinManifestSpecs):
|
91 |
+
def __init__(self, specs_dataset_cfg):
|
92 |
+
super().__init__('train', **specs_dataset_cfg)
|
93 |
+
|
94 |
+
class JoinSpecsValidation(JoinManifestSpecs):
|
95 |
+
def __init__(self, specs_dataset_cfg):
|
96 |
+
super().__init__('valid', **specs_dataset_cfg)
|
97 |
+
|
98 |
+
class JoinSpecsTest(JoinManifestSpecs):
|
99 |
+
def __init__(self, specs_dataset_cfg):
|
100 |
+
super().__init__('test', **specs_dataset_cfg)
|
101 |
+
|
102 |
+
|
103 |
+
|
ldm/data/joinaudiodataset_struct_sample_anylen.py
ADDED
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from typing import TypeVar, Optional, Iterator
|
5 |
+
import logging
|
6 |
+
import pandas as pd
|
7 |
+
from ldm.data.joinaudiodataset_anylen import *
|
8 |
+
import glob
|
9 |
+
logger = logging.getLogger(f'main.{__name__}')
|
10 |
+
|
11 |
+
sys.path.insert(0, '.') # nopep8
|
12 |
+
|
13 |
+
class JoinManifestSpecs(torch.utils.data.Dataset):
|
14 |
+
def __init__(self, split, main_spec_dir_path,other_spec_dir_path, mel_num=80,mode='pad', spec_crop_len=1248,pad_value=-5,drop=0,**kwargs):
|
15 |
+
super().__init__()
|
16 |
+
self.split = split
|
17 |
+
self.max_batch_len = spec_crop_len
|
18 |
+
self.min_batch_len = 64
|
19 |
+
self.min_factor = 4
|
20 |
+
self.mel_num = mel_num
|
21 |
+
self.drop = drop
|
22 |
+
self.pad_value = pad_value
|
23 |
+
assert mode in ['pad','tile']
|
24 |
+
self.collate_mode = mode
|
25 |
+
manifest_files = []
|
26 |
+
|
27 |
+
for dir_path in main_spec_dir_path.split(','):
|
28 |
+
manifest_files += glob.glob(f'{dir_path}/*.tsv')
|
29 |
+
df_list = [pd.read_csv(manifest,sep='\t') for manifest in manifest_files]
|
30 |
+
self.df_main = pd.concat(df_list,ignore_index=True)
|
31 |
+
|
32 |
+
manifest_files = []
|
33 |
+
for dir_path in other_spec_dir_path.split(','):
|
34 |
+
manifest_files += glob.glob(f'{dir_path}/*.tsv')
|
35 |
+
df_list = [pd.read_csv(manifest,sep='\t') for manifest in manifest_files]
|
36 |
+
# import ipdb
|
37 |
+
# ipdb.set_trace()
|
38 |
+
self.df_other = pd.concat(df_list,ignore_index=True)
|
39 |
+
self.df_other.reset_index(inplace=True)
|
40 |
+
|
41 |
+
if split == 'train':
|
42 |
+
self.dataset = self.df_main.iloc[100:]
|
43 |
+
elif split == 'valid' or split == 'val':
|
44 |
+
self.dataset = self.df_main.iloc[:100]
|
45 |
+
elif split == 'test':
|
46 |
+
self.df_main = self.add_name_num(self.df_main)
|
47 |
+
self.dataset = self.df_main
|
48 |
+
else:
|
49 |
+
raise ValueError(f'Unknown split {split}')
|
50 |
+
self.dataset.reset_index(inplace=True)
|
51 |
+
print('dataset len:', len(self.dataset),"drop_rate",self.drop)
|
52 |
+
|
53 |
+
def add_name_num(self,df):
|
54 |
+
"""each file may have different caption, we add num to filename to identify each audio-caption pair"""
|
55 |
+
name_count_dict = {}
|
56 |
+
change = []
|
57 |
+
for t in df.itertuples():
|
58 |
+
name = getattr(t,'name')
|
59 |
+
if name in name_count_dict:
|
60 |
+
name_count_dict[name] += 1
|
61 |
+
else:
|
62 |
+
name_count_dict[name] = 0
|
63 |
+
change.append((t[0],name_count_dict[name]))
|
64 |
+
for t in change:
|
65 |
+
df.loc[t[0],'name'] = str(df.loc[t[0],'name']) + f'_{t[1]}'
|
66 |
+
return df
|
67 |
+
|
68 |
+
def ordered_indices(self):
|
69 |
+
index2dur = self.dataset[['duration']].sort_values(by='duration')
|
70 |
+
index2dur_other = self.df_other[['duration']].sort_values(by='duration')
|
71 |
+
other_indices = list(index2dur_other.index)
|
72 |
+
offset = len(self.dataset)
|
73 |
+
other_indices = [x + offset for x in other_indices]
|
74 |
+
return list(index2dur.index),other_indices
|
75 |
+
# return list(index2dur.index)
|
76 |
+
|
77 |
+
def collater(self,inputs):
|
78 |
+
to_dict = {}
|
79 |
+
for l in inputs:
|
80 |
+
for k,v in l.items():
|
81 |
+
if k in to_dict:
|
82 |
+
to_dict[k].append(v)
|
83 |
+
else:
|
84 |
+
to_dict[k] = [v]
|
85 |
+
|
86 |
+
if self.collate_mode == 'pad':
|
87 |
+
to_dict['image'] = collate_1d_or_2d(to_dict['image'],pad_idx=self.pad_value,min_len = self.min_batch_len,max_len=self.max_batch_len,min_factor=self.min_factor)
|
88 |
+
elif self.collate_mode == 'tile':
|
89 |
+
to_dict['image'] = collate_1d_or_2d_tile(to_dict['image'],min_len = self.min_batch_len,max_len=self.max_batch_len,min_factor=self.min_factor)
|
90 |
+
else:
|
91 |
+
raise NotImplementedError
|
92 |
+
to_dict['caption'] = {'ori_caption':[c['ori_caption'] for c in to_dict['caption']],
|
93 |
+
'struct_caption':[c['struct_caption'] for c in to_dict['caption']]}
|
94 |
+
|
95 |
+
return to_dict
|
96 |
+
|
97 |
+
def __getitem__(self, idx):
|
98 |
+
if idx < len(self.dataset):
|
99 |
+
data = self.dataset.iloc[idx]
|
100 |
+
# p = np.random.uniform(0,1)
|
101 |
+
# if p > self.drop:
|
102 |
+
ori_caption = data['ori_cap']
|
103 |
+
struct_caption = data['caption']
|
104 |
+
# else:
|
105 |
+
# ori_caption = ""
|
106 |
+
# struct_caption = ""
|
107 |
+
else:
|
108 |
+
data = self.df_other.iloc[idx-len(self.dataset)]
|
109 |
+
# p = np.random.uniform(0,1)
|
110 |
+
# if p > self.drop:
|
111 |
+
ori_caption = data['caption']
|
112 |
+
struct_caption = f'<{ori_caption}& all>'
|
113 |
+
# else:
|
114 |
+
# ori_caption = ""
|
115 |
+
# struct_caption = ""
|
116 |
+
item = {}
|
117 |
+
try:
|
118 |
+
spec = np.load(data['mel_path']) # mel spec [80, T]
|
119 |
+
if spec.shape[1] > self.max_batch_len:
|
120 |
+
spec = spec[:,:self.max_batch_len]
|
121 |
+
except:
|
122 |
+
mel_path = data['mel_path']
|
123 |
+
print(f'corrupted:{mel_path}')
|
124 |
+
spec = np.ones((self.mel_num,self.min_batch_len)).astype(np.float32)*self.pad_value
|
125 |
+
|
126 |
+
item['image'] = spec
|
127 |
+
item["caption"] = {"ori_caption":ori_caption,"struct_caption":struct_caption}
|
128 |
+
if self.split == 'test':
|
129 |
+
item['f_name'] = data['name']
|
130 |
+
return item
|
131 |
+
|
132 |
+
def __len__(self):
|
133 |
+
return len(self.dataset) + len(self.df_other)
|
134 |
+
# return len(self.dataset)
|
135 |
+
|
136 |
+
|
137 |
+
class JoinSpecsTrain(JoinManifestSpecs):
|
138 |
+
def __init__(self, specs_dataset_cfg):
|
139 |
+
super().__init__('train', **specs_dataset_cfg)
|
140 |
+
|
141 |
+
class JoinSpecsValidation(JoinManifestSpecs):
|
142 |
+
def __init__(self, specs_dataset_cfg):
|
143 |
+
super().__init__('valid', **specs_dataset_cfg)
|
144 |
+
|
145 |
+
class JoinSpecsTest(JoinManifestSpecs):
|
146 |
+
def __init__(self, specs_dataset_cfg):
|
147 |
+
super().__init__('test', **specs_dataset_cfg)
|
148 |
+
|
149 |
+
|
150 |
+
|
151 |
+
class DDPIndexBatchSampler(Sampler):# 让长度相似的音频的indices合到一个batch中以避免过长的pad
|
152 |
+
def __init__(self, main_indices,other_indices,batch_size, num_replicas: Optional[int] = None,
|
153 |
+
# def __init__(self, main_indices,batch_size, num_replicas: Optional[int] = None,
|
154 |
+
rank: Optional[int] = None, shuffle: bool = True,
|
155 |
+
seed: int = 0, drop_last: bool = False) -> None:
|
156 |
+
if num_replicas is None:
|
157 |
+
if not dist.is_initialized():
|
158 |
+
# raise RuntimeError("Requires distributed package to be available")
|
159 |
+
print("Not in distributed mode")
|
160 |
+
num_replicas = 1
|
161 |
+
else:
|
162 |
+
num_replicas = dist.get_world_size()
|
163 |
+
if rank is None:
|
164 |
+
if not dist.is_initialized():
|
165 |
+
# raise RuntimeError("Requires distributed package to be available")
|
166 |
+
rank = 0
|
167 |
+
else:
|
168 |
+
rank = dist.get_rank()
|
169 |
+
if rank >= num_replicas or rank < 0:
|
170 |
+
raise ValueError(
|
171 |
+
"Invalid rank {}, rank should be in the interval"
|
172 |
+
" [0, {}]".format(rank, num_replicas - 1))
|
173 |
+
self.main_indices = main_indices
|
174 |
+
self.other_indices = other_indices
|
175 |
+
self.max_index = max(self.other_indices)
|
176 |
+
self.num_replicas = num_replicas
|
177 |
+
self.rank = rank
|
178 |
+
self.epoch = 0
|
179 |
+
self.drop_last = drop_last
|
180 |
+
self.batch_size = batch_size
|
181 |
+
self.shuffle = shuffle
|
182 |
+
self.batches = self.build_batches()
|
183 |
+
self.seed = seed
|
184 |
+
|
185 |
+
def set_epoch(self,epoch):
|
186 |
+
# print("!!!!!!!!!!!set epoch is called!!!!!!!!!!!!!!")
|
187 |
+
self.epoch = epoch
|
188 |
+
if self.shuffle:
|
189 |
+
np.random.seed(self.seed+self.epoch)
|
190 |
+
self.batches = self.build_batches()
|
191 |
+
|
192 |
+
def build_batches(self):
|
193 |
+
batches,batch = [],[]
|
194 |
+
for index in self.main_indices:
|
195 |
+
batch.append(index)
|
196 |
+
if len(batch) == self.batch_size:
|
197 |
+
batches.append(batch)
|
198 |
+
batch = []
|
199 |
+
if not self.drop_last and len(batch) > 0:
|
200 |
+
batches.append(batch)
|
201 |
+
selected_others = np.random.choice(len(self.other_indices),len(batches),replace=False)
|
202 |
+
for index in selected_others:
|
203 |
+
if index + self.batch_size > len(self.other_indices):
|
204 |
+
index = len(self.other_indices) - self.batch_size
|
205 |
+
batch = [self.other_indices[index + i] for i in range(self.batch_size)]
|
206 |
+
batches.append(batch)
|
207 |
+
self.batches = batches
|
208 |
+
if self.shuffle:
|
209 |
+
self.batches = np.random.permutation(self.batches)
|
210 |
+
if self.rank == 0:
|
211 |
+
print(f"rank: {self.rank}, batches_num {len(self.batches)}")
|
212 |
+
|
213 |
+
if self.drop_last and len(self.batches) % self.num_replicas != 0:
|
214 |
+
self.batches = self.batches[:len(self.batches)//self.num_replicas*self.num_replicas]
|
215 |
+
if len(self.batches) >= self.num_replicas:
|
216 |
+
self.batches = self.batches[self.rank::self.num_replicas]
|
217 |
+
else: # may happen in sanity checking
|
218 |
+
self.batches = [self.batches[0]]
|
219 |
+
if self.rank == 0:
|
220 |
+
print(f"after split batches_num {len(self.batches)}")
|
221 |
+
|
222 |
+
return self.batches
|
223 |
+
|
224 |
+
def __iter__(self) -> Iterator[List[int]]:
|
225 |
+
print(f"len(self.batches):{len(self.batches)}")
|
226 |
+
for batch in self.batches:
|
227 |
+
yield batch
|
228 |
+
|
229 |
+
def __len__(self) -> int:
|
230 |
+
return len(self.batches)
|
ldm/data/preprocess/NAT_mel.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.utils.data
|
4 |
+
from librosa.filters import mel as librosa_mel_fn
|
5 |
+
from scipy.io.wavfile import read
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
|
9 |
+
MAX_WAV_VALUE = 32768.0
|
10 |
+
|
11 |
+
|
12 |
+
def load_wav(full_path):
|
13 |
+
sampling_rate, data = read(full_path)
|
14 |
+
return data, sampling_rate
|
15 |
+
|
16 |
+
|
17 |
+
def dynamic_range_compression(x, C=1, clip_val=1e-5):
|
18 |
+
return np.log10(np.clip(x, a_min=clip_val, a_max=None) * C)
|
19 |
+
|
20 |
+
|
21 |
+
def dynamic_range_decompression(x, C=1):
|
22 |
+
return np.exp(x) / C
|
23 |
+
|
24 |
+
|
25 |
+
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
26 |
+
return torch.log10(torch.clamp(x, min=clip_val) * C)
|
27 |
+
|
28 |
+
|
29 |
+
def dynamic_range_decompression_torch(x, C=1):
|
30 |
+
return torch.exp(x) / C
|
31 |
+
|
32 |
+
|
33 |
+
def spectral_normalize_torch(magnitudes):
|
34 |
+
output = dynamic_range_compression_torch(magnitudes)
|
35 |
+
return output
|
36 |
+
|
37 |
+
|
38 |
+
def spectral_de_normalize_torch(magnitudes):
|
39 |
+
output = dynamic_range_decompression_torch(magnitudes)
|
40 |
+
return output
|
41 |
+
|
42 |
+
class MelNet(nn.Module):
|
43 |
+
def __init__(self,hparams,device='cpu') -> None:
|
44 |
+
super().__init__()
|
45 |
+
self.n_fft = hparams['fft_size']
|
46 |
+
self.num_mels = hparams['audio_num_mel_bins']
|
47 |
+
self.sampling_rate = hparams['audio_sample_rate']
|
48 |
+
self.hop_size = hparams['hop_size']
|
49 |
+
self.win_size = hparams['win_size']
|
50 |
+
self.fmin = hparams['fmin']
|
51 |
+
self.fmax = hparams['fmax']
|
52 |
+
self.device = device
|
53 |
+
|
54 |
+
mel = librosa_mel_fn(self.sampling_rate, self.n_fft, self.num_mels, self.fmin, self.fmax)
|
55 |
+
self.mel_basis = torch.from_numpy(mel).float().to(self.device)
|
56 |
+
self.hann_window = torch.hann_window(self.win_size).to(self.device)
|
57 |
+
|
58 |
+
def to(self,device,**kwagrs):
|
59 |
+
super().to(device=device,**kwagrs)
|
60 |
+
self.mel_basis = self.mel_basis.to(device)
|
61 |
+
self.hann_window = self.hann_window.to(device)
|
62 |
+
self.device = device
|
63 |
+
|
64 |
+
def forward(self,y,center=False, complex=False):
|
65 |
+
if isinstance(y,np.ndarray):
|
66 |
+
y = torch.FloatTensor(y)
|
67 |
+
if len(y.shape) == 1:
|
68 |
+
y = y.unsqueeze(0)
|
69 |
+
y = y.clamp(min=-1., max=1.).to(self.device)
|
70 |
+
|
71 |
+
y = torch.nn.functional.pad(y.unsqueeze(1), [int((self.n_fft - self.hop_size) / 2), int((self.n_fft - self.hop_size) / 2)],
|
72 |
+
mode='reflect')
|
73 |
+
y = y.squeeze(1)
|
74 |
+
|
75 |
+
spec = torch.stft(y, self.n_fft, hop_length=self.hop_size, win_length=self.win_size, window=self.hann_window,
|
76 |
+
center=center, pad_mode='reflect', normalized=False, onesided=True,return_complex=complex)
|
77 |
+
|
78 |
+
if not complex:
|
79 |
+
spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
|
80 |
+
spec = torch.matmul(self.mel_basis, spec)
|
81 |
+
spec = spectral_normalize_torch(spec)
|
82 |
+
else:
|
83 |
+
B, C, T, _ = spec.shape
|
84 |
+
spec = spec.transpose(1, 2) # [B, T, n_fft, 2]
|
85 |
+
return spec
|
86 |
+
|
87 |
+
## below can be used in one gpu, but not ddp
|
88 |
+
mel_basis = {}
|
89 |
+
hann_window = {}
|
90 |
+
|
91 |
+
|
92 |
+
def mel_spectrogram(y, hparams, center=False, complex=False): # y should be a tensor with shape (b,wav_len)
|
93 |
+
# hop_size: 512 # For 22050Hz, 275 ~= 12.5 ms (0.0125 * sample_rate)
|
94 |
+
# win_size: 2048 # For 22050Hz, 1100 ~= 50 ms (If None, win_size: fft_size) (0.05 * sample_rate)
|
95 |
+
# fmin: 55 # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
|
96 |
+
# fmax: 10000 # To be increased/reduced depending on data.
|
97 |
+
# fft_size: 2048 # Extra window size is filled with 0 paddings to match this parameter
|
98 |
+
# n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax,
|
99 |
+
n_fft = hparams['fft_size']
|
100 |
+
num_mels = hparams['audio_num_mel_bins']
|
101 |
+
sampling_rate = hparams['audio_sample_rate']
|
102 |
+
hop_size = hparams['hop_size']
|
103 |
+
win_size = hparams['win_size']
|
104 |
+
fmin = hparams['fmin']
|
105 |
+
fmax = hparams['fmax']
|
106 |
+
if isinstance(y,np.ndarray):
|
107 |
+
y = torch.FloatTensor(y)
|
108 |
+
if len(y.shape) == 1:
|
109 |
+
y = y.unsqueeze(0)
|
110 |
+
y = y.clamp(min=-1., max=1.)
|
111 |
+
global mel_basis, hann_window
|
112 |
+
if fmax not in mel_basis:
|
113 |
+
mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
|
114 |
+
mel_basis[str(fmax) + '_' + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
|
115 |
+
hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
|
116 |
+
|
117 |
+
y = torch.nn.functional.pad(y.unsqueeze(1), [int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)],
|
118 |
+
mode='reflect')
|
119 |
+
y = y.squeeze(1)
|
120 |
+
|
121 |
+
spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
|
122 |
+
center=center, pad_mode='reflect', normalized=False, onesided=True,return_complex=complex)
|
123 |
+
|
124 |
+
if not complex:
|
125 |
+
spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
|
126 |
+
spec = torch.matmul(mel_basis[str(fmax) + '_' + str(y.device)], spec)
|
127 |
+
spec = spectral_normalize_torch(spec)
|
128 |
+
else:
|
129 |
+
B, C, T, _ = spec.shape
|
130 |
+
spec = spec.transpose(1, 2) # [B, T, n_fft, 2]
|
131 |
+
return spec
|
ldm/data/preprocess/__pycache__/NAT_mel.cpython-38.pyc
ADDED
Binary file (4.25 kB). View file
|
|
ldm/data/preprocess/__pycache__/NAT_mel.cpython-39.pyc
ADDED
Binary file (4.23 kB). View file
|
|
ldm/data/preprocess/add_duration.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import audioread
|
3 |
+
from tqdm import tqdm
|
4 |
+
from tqdm.contrib.concurrent import process_map
|
5 |
+
|
6 |
+
def map_duration(tsv_withdur,tsv_toadd):# tsv_withdur 和 tsv_toadd 'name'列相同且tsv_withdur有duration信息,目标是给tsv_toadd的相同行加上duration信息。
|
7 |
+
df1 = pd.read_csv(tsv_withdur,sep='\t')
|
8 |
+
df2 = pd.read_csv(tsv_toadd,sep='\t')
|
9 |
+
|
10 |
+
df = df2.merge(df1,on=['name'],suffixes=['','_y'])
|
11 |
+
dropset = list(set(df.columns) - set(df1.columns))
|
12 |
+
df = df.drop(dropset,axis=1)
|
13 |
+
df.to_csv(tsv_toadd,sep='\t',index=False)
|
14 |
+
return df
|
15 |
+
|
16 |
+
def add_duration(args):
|
17 |
+
index,audiopath = args
|
18 |
+
try:
|
19 |
+
with audioread.audio_open(audiopath) as f:
|
20 |
+
totalsec = f.duration
|
21 |
+
except:
|
22 |
+
totalsec = -1
|
23 |
+
return (index,totalsec)
|
24 |
+
|
25 |
+
def add_dur2tsv(tsv_path,save_path):
|
26 |
+
df = pd.read_csv(tsv_path,sep='\t')
|
27 |
+
item_list = []
|
28 |
+
for item in tqdm(df.itertuples()):
|
29 |
+
item_list.append((item[0],getattr(item,'audio_path')))
|
30 |
+
|
31 |
+
r = process_map(add_duration,item_list,max_workers=16,chunksize=32)
|
32 |
+
index2dur = {}
|
33 |
+
for index,dur in r:
|
34 |
+
if dur == -1:
|
35 |
+
bad_wav = df.loc[index,'audio_path']
|
36 |
+
print(f'bad wav:{bad_wav}')
|
37 |
+
index2dur[index] = dur
|
38 |
+
|
39 |
+
df['duration'] = df.index.map(index2dur)
|
40 |
+
df.to_csv(save_path,sep='\t',index=False)
|
41 |
+
|
42 |
+
if __name__ == '__main__':
|
43 |
+
add_dur2tsv('/root/autodl-tmp/liuhuadai/AudioLCM/now.tsv','/root/autodl-tmp/liuhuadai/AudioLCM/now_duration.tsv')
|
44 |
+
#map_duration(tsv_withdur='tsv_maker/filter_audioset.tsv',
|
45 |
+
# tsv_toadd='MAA1 Dataset tsvs/V3/refilter_audioset.tsv')
|
ldm/data/preprocess/mel_spec.py
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from ldm.data.preprocess.NAT_mel import MelNet
|
2 |
+
import os
|
3 |
+
from tqdm import tqdm
|
4 |
+
from glob import glob
|
5 |
+
import math
|
6 |
+
import pandas as pd
|
7 |
+
import logging
|
8 |
+
import math
|
9 |
+
import audioread
|
10 |
+
from tqdm.contrib.concurrent import process_map
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torchaudio
|
14 |
+
import numpy as np
|
15 |
+
from torch.distributed import init_process_group
|
16 |
+
from torch.utils.data import Dataset,DataLoader,DistributedSampler
|
17 |
+
import torch.multiprocessing as mp
|
18 |
+
from argparse import Namespace
|
19 |
+
from multiprocessing import Pool
|
20 |
+
import json
|
21 |
+
|
22 |
+
|
23 |
+
class tsv_dataset(Dataset):
|
24 |
+
def __init__(self,tsv_path,sr,mode='none',hop_size = None,target_mel_length = None) -> None:
|
25 |
+
super().__init__()
|
26 |
+
if os.path.isdir(tsv_path):
|
27 |
+
files = glob(os.path.join(tsv_path,'*.tsv'))
|
28 |
+
df = pd.concat([pd.read_csv(file,sep='\t') for file in files])
|
29 |
+
else:
|
30 |
+
df = pd.read_csv(tsv_path,sep='\t')
|
31 |
+
self.audio_paths = []
|
32 |
+
self.sr = sr
|
33 |
+
self.mode = mode
|
34 |
+
self.target_mel_length = target_mel_length
|
35 |
+
self.hop_size = hop_size
|
36 |
+
for t in tqdm(df.itertuples()):
|
37 |
+
self.audio_paths.append(getattr(t,'audio_path'))
|
38 |
+
|
39 |
+
def __len__(self):
|
40 |
+
return len(self.audio_paths)
|
41 |
+
|
42 |
+
def pad_wav(self,wav):
|
43 |
+
# wav should be in shape(1,wav_len)
|
44 |
+
wav_length = wav.shape[-1]
|
45 |
+
assert wav_length > 100, "wav is too short, %s" % wav_length
|
46 |
+
segment_length = (self.target_mel_length + 1) * self.hop_size # final mel will crop the last mel, mel = mel[:,:-1]
|
47 |
+
if segment_length is None or wav_length == segment_length:
|
48 |
+
return wav
|
49 |
+
elif wav_length > segment_length:
|
50 |
+
return wav[:,:segment_length]
|
51 |
+
elif wav_length < segment_length:
|
52 |
+
temp_wav = torch.zeros((1, segment_length),dtype=torch.float32)
|
53 |
+
temp_wav[:, :wav_length] = wav
|
54 |
+
return temp_wav
|
55 |
+
|
56 |
+
|
57 |
+
def __getitem__(self, index):
|
58 |
+
audio_path = self.audio_paths[index]
|
59 |
+
wav, orisr = torchaudio.load(audio_path)
|
60 |
+
if wav.shape[0] != 1: # stereo to mono (2,wav_len) -> (1,wav_len)
|
61 |
+
wav = wav.mean(0,keepdim=True)
|
62 |
+
wav = torchaudio.functional.resample(wav, orig_freq=orisr, new_freq=self.sr)
|
63 |
+
if self.mode == 'pad':
|
64 |
+
assert self.target_mel_length is not None
|
65 |
+
wav = self.pad_wav(wav)
|
66 |
+
return audio_path,wav
|
67 |
+
|
68 |
+
def process_audio_by_tsv(rank,args):
|
69 |
+
if args.num_gpus > 1:
|
70 |
+
init_process_group(backend=args.dist_config['dist_backend'], init_method=args.dist_config['dist_url'],
|
71 |
+
world_size=args.dist_config['world_size'] * args.num_gpus, rank=rank)
|
72 |
+
|
73 |
+
sr = args.audio_sample_rate
|
74 |
+
dataset = tsv_dataset(args.tsv_path,sr = sr,mode=args.mode,hop_size=args.hop_size,target_mel_length=args.batch_max_length)
|
75 |
+
sampler = DistributedSampler(dataset,shuffle=False) if args.num_gpus > 1 else None
|
76 |
+
# batch_size must == 1,since wav_len is not equal
|
77 |
+
loader = DataLoader(dataset, sampler=sampler,batch_size=1, num_workers=16,drop_last=False)
|
78 |
+
|
79 |
+
device = torch.device('cuda:{:d}'.format(rank))
|
80 |
+
|
81 |
+
mel_net = MelNet(args.__dict__)
|
82 |
+
mel_net.to(device)
|
83 |
+
# if args.num_gpus > 1: # RuntimeError: DistributedDataParallel is not needed when a module doesn't have any parameter that requires a gradient.
|
84 |
+
# mel_net = DistributedDataParallel(mel_net, device_ids=[rank]).to(device)
|
85 |
+
|
86 |
+
loader = tqdm(loader) if rank == 0 else loader
|
87 |
+
for batch in loader:
|
88 |
+
audio_paths,wavs = batch
|
89 |
+
wavs = wavs.to(device)
|
90 |
+
if args.save_resample:
|
91 |
+
for audio_path,wav in zip(audio_paths,wavs):
|
92 |
+
psplits = audio_path.split('/')
|
93 |
+
root,wav_name = psplits[0],psplits[-1]
|
94 |
+
# save resample
|
95 |
+
resample_root,resample_name = root+f'_{sr}',wav_name[:-4]+'_audio.npy'
|
96 |
+
resample_dir_name = os.path.join(resample_root,*psplits[1:-1])
|
97 |
+
resample_path = os.path.join(resample_dir_name,resample_name)
|
98 |
+
os.makedirs(resample_dir_name,exist_ok=True)
|
99 |
+
np.save(resample_path,wav.cpu().numpy().squeeze(0))
|
100 |
+
|
101 |
+
if args.save_mel:
|
102 |
+
mode = args.mode
|
103 |
+
batch_max_length = args.batch_max_length
|
104 |
+
|
105 |
+
for audio_path,wav in zip(audio_paths,wavs):
|
106 |
+
psplits = audio_path.split('/')
|
107 |
+
root,wav_name = psplits[0],psplits[-1]
|
108 |
+
mel_root,mel_name = root+f'_mel{mode}{sr}nfft{args.fft_size}',wav_name[:-4]+'_mel.npy'
|
109 |
+
mel_dir_name = os.path.join(mel_root,*psplits[1:-1])
|
110 |
+
mel_path = os.path.join(mel_dir_name,mel_name)
|
111 |
+
if not os.path.exists(mel_path):
|
112 |
+
mel_spec = mel_net(wav).cpu().numpy().squeeze(0) # (mel_bins,mel_len)
|
113 |
+
if mel_spec.shape[1] <= batch_max_length:
|
114 |
+
if mode == 'tile': # pad is done in dataset as pad wav
|
115 |
+
n_repeat = math.ceil((batch_max_length + 1) / mel_spec.shape[1])
|
116 |
+
mel_spec = np.tile(mel_spec,reps=(1,n_repeat))
|
117 |
+
elif mode == 'none' or mode == 'pad':
|
118 |
+
pass
|
119 |
+
else:
|
120 |
+
raise ValueError(f'mode:{mode} is not supported')
|
121 |
+
mel_spec = mel_spec[:,:batch_max_length]
|
122 |
+
os.makedirs(mel_dir_name,exist_ok=True)
|
123 |
+
np.save(mel_path,mel_spec)
|
124 |
+
|
125 |
+
|
126 |
+
def split_list(i_list,num):
|
127 |
+
each_num = math.ceil(i_list / num)
|
128 |
+
result = []
|
129 |
+
for i in range(num):
|
130 |
+
s = each_num * i
|
131 |
+
e = (each_num * (i+1))
|
132 |
+
result.append(i_list[s:e])
|
133 |
+
return result
|
134 |
+
|
135 |
+
|
136 |
+
def drop_bad_wav(item):
|
137 |
+
index,path = item
|
138 |
+
try:
|
139 |
+
with audioread.audio_open(path) as f:
|
140 |
+
totalsec = f.duration
|
141 |
+
if totalsec < 0.1:
|
142 |
+
return index # index
|
143 |
+
except:
|
144 |
+
print(f"corrupted wav:{path}")
|
145 |
+
return index
|
146 |
+
return False
|
147 |
+
|
148 |
+
def drop_bad_wavs(tsv_path):# 'audioset.csv'
|
149 |
+
df = pd.read_csv(tsv_path,sep='\t')
|
150 |
+
item_list = []
|
151 |
+
for item in tqdm(df.itertuples()):
|
152 |
+
item_list.append((item[0],getattr(item,'audio_path')))
|
153 |
+
|
154 |
+
r = process_map(drop_bad_wav,item_list,max_workers=16,chunksize=16)
|
155 |
+
bad_indices = list(filter(lambda x:x!= False,r))
|
156 |
+
|
157 |
+
print(bad_indices)
|
158 |
+
with open('bad_wavs.json','w') as f:
|
159 |
+
x = [item_list[i] for i in bad_indices]
|
160 |
+
json.dump(x,f)
|
161 |
+
df = df.drop(bad_indices,axis=0)
|
162 |
+
df.to_csv(tsv_path,sep='\t',index=False)
|
163 |
+
|
164 |
+
if __name__ == '__main__':
|
165 |
+
logging.basicConfig(filename='example.log', level=logging.INFO,
|
166 |
+
format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p')
|
167 |
+
tsv_path = './musiccap.tsv'
|
168 |
+
if os.path.isdir(tsv_path):
|
169 |
+
files = glob(os.path.join(tsv_path,'*.tsv'))
|
170 |
+
for file in files:
|
171 |
+
drop_bad_wavs(file)
|
172 |
+
else:
|
173 |
+
drop_bad_wavs(tsv_path)
|
174 |
+
num_gpus = 1
|
175 |
+
args = {
|
176 |
+
'audio_sample_rate': 16000,
|
177 |
+
'audio_num_mel_bins':80,
|
178 |
+
'fft_size': 1024,# 4000:512 ,16000:1024,
|
179 |
+
'win_size': 1024,
|
180 |
+
'hop_size': 256,
|
181 |
+
'fmin': 0,
|
182 |
+
'fmax': 8000,
|
183 |
+
'batch_max_length': 1560, # 4000:312 (nfft = 512,hoplen=128,mellen = 313), 16000:624 , 22050:848 #
|
184 |
+
'tsv_path': tsv_path,
|
185 |
+
'num_gpus': num_gpus,
|
186 |
+
'mode': 'none',
|
187 |
+
'save_resample':False,
|
188 |
+
'save_mel' :True
|
189 |
+
}
|
190 |
+
args = Namespace(**args)
|
191 |
+
args.dist_config = {
|
192 |
+
"dist_backend": "nccl",
|
193 |
+
"dist_url": "tcp://localhost:54189",
|
194 |
+
"world_size": 1
|
195 |
+
}
|
196 |
+
if args.num_gpus>1:
|
197 |
+
mp.spawn(process_audio_by_tsv,nprocs=args.num_gpus,args=(args,))
|
198 |
+
else:
|
199 |
+
process_audio_by_tsv(0,args=args)
|
200 |
+
print("done")
|
201 |
+
|
ldm/data/test.py
ADDED
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from typing import TypeVar, Optional, Iterator
|
5 |
+
import logging
|
6 |
+
import pandas as pd
|
7 |
+
from ldm.data.joinaudiodataset_anylen import *
|
8 |
+
import glob
|
9 |
+
logger = logging.getLogger(f'main.{__name__}')
|
10 |
+
|
11 |
+
sys.path.insert(0, '.') # nopep8
|
12 |
+
|
13 |
+
class JoinManifestSpecs(torch.utils.data.Dataset):
|
14 |
+
def __init__(self, split, main_spec_dir_path,other_spec_dir_path, mel_num=80,mode='pad', spec_crop_len=1248,pad_value=-5,drop=0,**kwargs):
|
15 |
+
super().__init__()
|
16 |
+
self.split = split
|
17 |
+
self.max_batch_len = spec_crop_len
|
18 |
+
self.min_batch_len = 64
|
19 |
+
self.min_factor = 4
|
20 |
+
self.mel_num = mel_num
|
21 |
+
self.drop = drop
|
22 |
+
self.pad_value = pad_value
|
23 |
+
assert mode in ['pad','tile']
|
24 |
+
self.collate_mode = mode
|
25 |
+
manifest_files = []
|
26 |
+
for dir_path in main_spec_dir_path.split(','):
|
27 |
+
manifest_files += glob.glob(f'{dir_path}/*.tsv')
|
28 |
+
df_list = [pd.read_csv(manifest,sep='\t') for manifest in manifest_files]
|
29 |
+
self.df_main = pd.concat(df_list,ignore_index=True)
|
30 |
+
|
31 |
+
manifest_files = []
|
32 |
+
for dir_path in other_spec_dir_path.split(','):
|
33 |
+
manifest_files += glob.glob(f'{dir_path}/*.tsv')
|
34 |
+
df_list = [pd.read_csv(manifest,sep='\t') for manifest in manifest_files]
|
35 |
+
self.df_other = pd.concat(df_list,ignore_index=True)
|
36 |
+
self.df_other.reset_index(inplace=True)
|
37 |
+
|
38 |
+
if split == 'train':
|
39 |
+
self.dataset = self.df_main.iloc[100:]
|
40 |
+
elif split == 'valid' or split == 'val':
|
41 |
+
self.dataset = self.df_main.iloc[:100]
|
42 |
+
elif split == 'test':
|
43 |
+
self.df_main = self.add_name_num(self.df_main)
|
44 |
+
self.dataset = self.df_main
|
45 |
+
else:
|
46 |
+
raise ValueError(f'Unknown split {split}')
|
47 |
+
self.dataset.reset_index(inplace=True)
|
48 |
+
print('dataset len:', len(self.dataset),"drop_rate",self.drop)
|
49 |
+
|
50 |
+
def add_name_num(self,df):
|
51 |
+
"""each file may have different caption, we add num to filename to identify each audio-caption pair"""
|
52 |
+
name_count_dict = {}
|
53 |
+
change = []
|
54 |
+
for t in df.itertuples():
|
55 |
+
name = getattr(t,'name')
|
56 |
+
if name in name_count_dict:
|
57 |
+
name_count_dict[name] += 1
|
58 |
+
else:
|
59 |
+
name_count_dict[name] = 0
|
60 |
+
change.append((t[0],name_count_dict[name]))
|
61 |
+
for t in change:
|
62 |
+
df.loc[t[0],'name'] = str(df.loc[t[0],'name']) + f'_{t[1]}'
|
63 |
+
return df
|
64 |
+
|
65 |
+
def ordered_indices(self):
|
66 |
+
index2dur = self.dataset[['duration']].sort_values(by='duration')
|
67 |
+
index2dur_other = self.df_other[['duration']].sort_values(by='duration')
|
68 |
+
other_indices = list(index2dur_other.index)
|
69 |
+
offset = len(self.dataset)
|
70 |
+
other_indices = [x + offset for x in other_indices]
|
71 |
+
return list(index2dur.index),other_indices
|
72 |
+
|
73 |
+
def collater(self,inputs):
|
74 |
+
to_dict = {}
|
75 |
+
for l in inputs:
|
76 |
+
for k,v in l.items():
|
77 |
+
if k in to_dict:
|
78 |
+
to_dict[k].append(v)
|
79 |
+
else:
|
80 |
+
to_dict[k] = [v]
|
81 |
+
|
82 |
+
if self.collate_mode == 'pad':
|
83 |
+
to_dict['image'] = collate_1d_or_2d(to_dict['image'],pad_idx=self.pad_value,min_len = self.min_batch_len,max_len=self.max_batch_len,min_factor=self.min_factor)
|
84 |
+
elif self.collate_mode == 'tile':
|
85 |
+
to_dict['image'] = collate_1d_or_2d_tile(to_dict['image'],min_len = self.min_batch_len,max_len=self.max_batch_len,min_factor=self.min_factor)
|
86 |
+
else:
|
87 |
+
raise NotImplementedError
|
88 |
+
to_dict['caption'] = {'ori_caption':[c['ori_caption'] for c in to_dict['caption']],
|
89 |
+
'struct_caption':[c['struct_caption'] for c in to_dict['caption']]}
|
90 |
+
|
91 |
+
return to_dict
|
92 |
+
|
93 |
+
def __getitem__(self, idx):
|
94 |
+
if idx < len(self.dataset):
|
95 |
+
data = self.dataset.iloc[idx]
|
96 |
+
p = np.random.uniform(0,1)
|
97 |
+
if p > self.drop:
|
98 |
+
ori_caption = data['ori_cap']
|
99 |
+
struct_caption = data['caption']
|
100 |
+
else:
|
101 |
+
ori_caption = ""
|
102 |
+
struct_caption = ""
|
103 |
+
else:
|
104 |
+
data = self.df_other.iloc[idx-len(self.dataset)]
|
105 |
+
p = np.random.uniform(0,1)
|
106 |
+
if p > self.drop:
|
107 |
+
ori_caption = data['caption']
|
108 |
+
struct_caption = f'<{ori_caption}& all>'
|
109 |
+
else:
|
110 |
+
ori_caption = ""
|
111 |
+
struct_caption = ""
|
112 |
+
item = {}
|
113 |
+
try:
|
114 |
+
spec = np.load(data['mel_path']) # mel spec [80, T]
|
115 |
+
if spec.shape[1] > self.max_batch_len:
|
116 |
+
spec = spec[:,:self.max_batch_len]
|
117 |
+
except:
|
118 |
+
mel_path = data['mel_path']
|
119 |
+
print(f'corrupted:{mel_path}')
|
120 |
+
spec = np.ones((self.mel_num,self.min_batch_len)).astype(np.float32)*self.pad_value
|
121 |
+
|
122 |
+
item['image'] = spec
|
123 |
+
item["caption"] = {"ori_caption":ori_caption,"struct_caption":struct_caption}
|
124 |
+
if self.split == 'test':
|
125 |
+
item['f_name'] = data['name']
|
126 |
+
return item
|
127 |
+
|
128 |
+
def __len__(self):
|
129 |
+
return len(self.dataset) + len(self.df_other)
|
130 |
+
|
131 |
+
|
132 |
+
class JoinSpecsTrain(JoinManifestSpecs):
|
133 |
+
def __init__(self, specs_dataset_cfg):
|
134 |
+
super().__init__('train', **specs_dataset_cfg)
|
135 |
+
|
136 |
+
class JoinSpecsValidation(JoinManifestSpecs):
|
137 |
+
def __init__(self, specs_dataset_cfg):
|
138 |
+
super().__init__('valid', **specs_dataset_cfg)
|
139 |
+
|
140 |
+
class JoinSpecsTest(JoinManifestSpecs):
|
141 |
+
def __init__(self, specs_dataset_cfg):
|
142 |
+
super().__init__('test', **specs_dataset_cfg)
|
143 |
+
|
144 |
+
|
145 |
+
|
146 |
+
class DDPIndexBatchSampler(Sampler):# 让长度相似的音频的indices合到一个batch中以避免过长的pad
|
147 |
+
def __init__(self, main_indices,other_indices,batch_size, num_replicas: Optional[int] = None,
|
148 |
+
rank: Optional[int] = None, shuffle: bool = True,
|
149 |
+
seed: int = 0, drop_last: bool = False) -> None:
|
150 |
+
if num_replicas is None:
|
151 |
+
if not dist.is_initialized():
|
152 |
+
# raise RuntimeError("Requires distributed package to be available")
|
153 |
+
print("Not in distributed mode")
|
154 |
+
num_replicas = 1
|
155 |
+
else:
|
156 |
+
num_replicas = dist.get_world_size()
|
157 |
+
if rank is None:
|
158 |
+
if not dist.is_initialized():
|
159 |
+
# raise RuntimeError("Requires distributed package to be available")
|
160 |
+
rank = 0
|
161 |
+
else:
|
162 |
+
rank = dist.get_rank()
|
163 |
+
if rank >= num_replicas or rank < 0:
|
164 |
+
raise ValueError(
|
165 |
+
"Invalid rank {}, rank should be in the interval"
|
166 |
+
" [0, {}]".format(rank, num_replicas - 1))
|
167 |
+
self.main_indices = main_indices
|
168 |
+
self.other_indices = other_indices
|
169 |
+
self.max_index = max(self.other_indices)
|
170 |
+
self.num_replicas = num_replicas
|
171 |
+
self.rank = rank
|
172 |
+
self.epoch = 0
|
173 |
+
self.drop_last = drop_last
|
174 |
+
self.batch_size = batch_size
|
175 |
+
self.shuffle = shuffle
|
176 |
+
self.batches = self.build_batches()
|
177 |
+
self.seed = seed
|
178 |
+
|
179 |
+
def set_epoch(self,epoch):
|
180 |
+
# print("!!!!!!!!!!!set epoch is called!!!!!!!!!!!!!!")
|
181 |
+
self.epoch = epoch
|
182 |
+
if self.shuffle:
|
183 |
+
np.random.seed(self.seed+self.epoch)
|
184 |
+
self.batches = self.build_batches()
|
185 |
+
|
186 |
+
def build_batches(self):
|
187 |
+
batches,batch = [],[]
|
188 |
+
for index in self.main_indices:
|
189 |
+
batch.append(index)
|
190 |
+
if len(batch) == self.batch_size:
|
191 |
+
batches.append(batch)
|
192 |
+
batch = []
|
193 |
+
if not self.drop_last and len(batch) > 0:
|
194 |
+
batches.append(batch)
|
195 |
+
selected_others = np.random.choice(len(self.other_indices),len(batches),replace=False)
|
196 |
+
for index in selected_others:
|
197 |
+
if index + self.batch_size > len(self.other_indices):
|
198 |
+
index = len(self.other_indices) - self.batch_size
|
199 |
+
batch = [self.other_indices[index + i] for i in range(self.batch_size)]
|
200 |
+
batches.append(batch)
|
201 |
+
self.batches = batches
|
202 |
+
if self.shuffle:
|
203 |
+
self.batches = np.random.permutation(self.batches)
|
204 |
+
if self.rank == 0:
|
205 |
+
print(f"rank: {self.rank}, batches_num {len(self.batches)}")
|
206 |
+
|
207 |
+
if self.drop_last and len(self.batches) % self.num_replicas != 0:
|
208 |
+
self.batches = self.batches[:len(self.batches)//self.num_replicas*self.num_replicas]
|
209 |
+
if len(self.batches) >= self.num_replicas:
|
210 |
+
self.batches = self.batches[self.rank::self.num_replicas]
|
211 |
+
else: # may happen in sanity checking
|
212 |
+
self.batches = [self.batches[0]]
|
213 |
+
if self.rank == 0:
|
214 |
+
print(f"after split batches_num {len(self.batches)}")
|
215 |
+
|
216 |
+
return self.batches
|
217 |
+
|
218 |
+
def __iter__(self) -> Iterator[List[int]]:
|
219 |
+
print(f"len(self.batches):{len(self.batches)}")
|
220 |
+
for batch in self.batches:
|
221 |
+
yield batch
|
222 |
+
|
223 |
+
def __len__(self) -> int:
|
224 |
+
return len(self.batches)
|
ldm/data/tsv_dirs/full_data/V1_new/audiocaps_train_16000.tsv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
ldm/data/tsv_dirs/full_data/V2/MACS.tsv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
ldm/data/tsv_dirs/full_data/V2/WavText5K.tsv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
ldm/data/tsv_dirs/full_data/V2/adobe.tsv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
ldm/data/tsv_dirs/full_data/V2/audiostock.tsv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
ldm/data/tsv_dirs/full_data/V2/epidemic_sound.tsv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:dc67e42c9defa98edfc2c6b23c731fafa4a22307fddfd1fb95ccfc00d0168951
|
3 |
+
size 15062608
|
ldm/data/tsv_dirs/full_data/caps_struct/audiocaps_train_16000_struct2.tsv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
ldm/data/tsv_dirs/full_data/clotho.tsv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
ldm/data/tsvdataset.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from glob import glob
|
2 |
+
from torch.utils.data import Dataset
|
3 |
+
import numpy as np
|
4 |
+
import pandas as pd
|
5 |
+
|
6 |
+
class TSVDataset(Dataset):
|
7 |
+
def __init__(self, tsv_path, spec_crop_len=None):
|
8 |
+
super().__init__()
|
9 |
+
self.batch_max_length = spec_crop_len
|
10 |
+
self.batch_min_length = 50
|
11 |
+
df = pd.read_csv(tsv_path,sep='\t')
|
12 |
+
df = self.add_name_num(df)
|
13 |
+
self.dataset = df
|
14 |
+
print('dataset len:', len(self.dataset))
|
15 |
+
|
16 |
+
def add_name_num(self,df):
|
17 |
+
"""each file may have different caption, we add num to filename to identify each audio-caption pair"""
|
18 |
+
name_count_dict = {}
|
19 |
+
change = []
|
20 |
+
for t in df.itertuples():
|
21 |
+
name = getattr(t,'name')
|
22 |
+
if name in name_count_dict:
|
23 |
+
name_count_dict[name] += 1
|
24 |
+
else:
|
25 |
+
name_count_dict[name] = 0
|
26 |
+
change.append((t[0],name_count_dict[name]))
|
27 |
+
for t in change:
|
28 |
+
df.loc[t[0],'name'] = df.loc[t[0],'name'] + f'_{t[1]}'
|
29 |
+
return df
|
30 |
+
|
31 |
+
|
32 |
+
def __getitem__(self, idx):
|
33 |
+
data = self.dataset.iloc[idx]
|
34 |
+
item = {}
|
35 |
+
spec = np.load(data['mel_path']) # mel spec [80, 624]
|
36 |
+
if spec.shape[1] <= self.batch_max_length:
|
37 |
+
spec = np.pad(spec, ((0, 0), (0, self.batch_max_length - spec.shape[1]))) # [80, 624]
|
38 |
+
|
39 |
+
item['image'] = spec
|
40 |
+
item["caption"] = data['caption']
|
41 |
+
item["f_name"] = data['name']
|
42 |
+
return item
|
43 |
+
|
44 |
+
def __len__(self):
|
45 |
+
return len(self.dataset)
|
46 |
+
|
47 |
+
class TSVDatasetStruct(TSVDataset):
|
48 |
+
def __getitem__(self, idx):
|
49 |
+
data = self.dataset.iloc[idx]
|
50 |
+
item = {}
|
51 |
+
spec = np.load(data['mel_path']) # mel spec [80, 624]
|
52 |
+
if spec.shape[1] <= self.batch_max_length:
|
53 |
+
spec = np.pad(spec, ((0, 0), (0, self.batch_max_length - spec.shape[1]))) # [80, 624]
|
54 |
+
|
55 |
+
item['image'] = spec[:,:self.batch_max_length]
|
56 |
+
item["caption"] = {'ori_caption':data['ori_cap'],'struct_caption':data['caption']}
|
57 |
+
item["f_name"] = data['name']
|
58 |
+
return item
|
59 |
+
|
60 |
+
class TSVDatasetTestFake(TSVDataset):
|
61 |
+
def __init__(self, specs_dataset_cfg):
|
62 |
+
super().__init__(phase='test', **specs_dataset_cfg)
|
63 |
+
self.dataset = [self.dataset[0]]
|
64 |
+
|
65 |
+
|
66 |
+
|
67 |
+
|
ldm/lr_scheduler.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
class LambdaWarmUpCosineScheduler:
|
5 |
+
"""
|
6 |
+
note: use with a base_lr of 1.0
|
7 |
+
"""
|
8 |
+
def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
|
9 |
+
self.lr_warm_up_steps = warm_up_steps
|
10 |
+
self.lr_start = lr_start
|
11 |
+
self.lr_min = lr_min
|
12 |
+
self.lr_max = lr_max
|
13 |
+
self.lr_max_decay_steps = max_decay_steps
|
14 |
+
self.last_lr = 0.
|
15 |
+
self.verbosity_interval = verbosity_interval
|
16 |
+
|
17 |
+
def schedule(self, n, **kwargs):
|
18 |
+
if self.verbosity_interval > 0:
|
19 |
+
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
|
20 |
+
if n < self.lr_warm_up_steps:
|
21 |
+
lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
|
22 |
+
self.last_lr = lr
|
23 |
+
return lr
|
24 |
+
else:
|
25 |
+
t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
|
26 |
+
t = min(t, 1.0)
|
27 |
+
lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
|
28 |
+
1 + np.cos(t * np.pi))
|
29 |
+
self.last_lr = lr
|
30 |
+
return lr
|
31 |
+
|
32 |
+
def __call__(self, n, **kwargs):
|
33 |
+
return self.schedule(n,**kwargs)
|
34 |
+
|
35 |
+
|
36 |
+
class LambdaWarmUpCosineScheduler2:
|
37 |
+
"""
|
38 |
+
supports repeated iterations, configurable via lists
|
39 |
+
note: use with a base_lr of 1.0.
|
40 |
+
"""
|
41 |
+
def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
|
42 |
+
assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
|
43 |
+
self.lr_warm_up_steps = warm_up_steps
|
44 |
+
self.f_start = f_start
|
45 |
+
self.f_min = f_min
|
46 |
+
self.f_max = f_max
|
47 |
+
self.cycle_lengths = cycle_lengths
|
48 |
+
self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
|
49 |
+
self.last_f = 0.
|
50 |
+
self.verbosity_interval = verbosity_interval
|
51 |
+
|
52 |
+
def find_in_interval(self, n):
|
53 |
+
interval = 0
|
54 |
+
for cl in self.cum_cycles[1:]:
|
55 |
+
if n <= cl:
|
56 |
+
return interval
|
57 |
+
interval += 1
|
58 |
+
|
59 |
+
def schedule(self, n, **kwargs):
|
60 |
+
cycle = self.find_in_interval(n)
|
61 |
+
n = n - self.cum_cycles[cycle]
|
62 |
+
if self.verbosity_interval > 0:
|
63 |
+
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
64 |
+
f"current cycle {cycle}")
|
65 |
+
if n < self.lr_warm_up_steps[cycle]:
|
66 |
+
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
|
67 |
+
self.last_f = f
|
68 |
+
return f
|
69 |
+
else:
|
70 |
+
t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
|
71 |
+
t = min(t, 1.0)
|
72 |
+
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
|
73 |
+
1 + np.cos(t * np.pi))
|
74 |
+
self.last_f = f
|
75 |
+
return f
|
76 |
+
|
77 |
+
def __call__(self, n, **kwargs):
|
78 |
+
return self.schedule(n, **kwargs)
|
79 |
+
|
80 |
+
|
81 |
+
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
|
82 |
+
|
83 |
+
def schedule(self, n, **kwargs):
|
84 |
+
cycle = self.find_in_interval(n)
|
85 |
+
n = n - self.cum_cycles[cycle]
|
86 |
+
if self.verbosity_interval > 0:
|
87 |
+
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
88 |
+
f"current cycle {cycle}")
|
89 |
+
|
90 |
+
if n < self.lr_warm_up_steps[cycle]:
|
91 |
+
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
|
92 |
+
self.last_f = f
|
93 |
+
return f
|
94 |
+
else:
|
95 |
+
f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
|
96 |
+
self.last_f = f
|
97 |
+
return f
|
98 |
+
|
ldm/models/__pycache__/autoencoder.cpython-37.pyc
ADDED
Binary file (15.6 kB). View file
|
|
ldm/models/__pycache__/autoencoder.cpython-38.pyc
ADDED
Binary file (15.5 kB). View file
|
|
ldm/models/__pycache__/autoencoder.cpython-39.pyc
ADDED
Binary file (14.9 kB). View file
|
|
ldm/models/__pycache__/autoencoder1d.cpython-37.pyc
ADDED
Binary file (13.5 kB). View file
|
|
ldm/models/__pycache__/autoencoder1d.cpython-38.pyc
ADDED
Binary file (13.4 kB). View file
|
|
ldm/models/__pycache__/autoencoder_multi.cpython-38.pyc
ADDED
Binary file (14.8 kB). View file
|
|
ldm/models/autoencoder.py
ADDED
@@ -0,0 +1,504 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import pytorch_lightning as pl
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from contextlib import contextmanager
|
6 |
+
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
|
7 |
+
from packaging import version
|
8 |
+
import numpy as np
|
9 |
+
from ldm.modules.diffusionmodules.model import Encoder, Decoder
|
10 |
+
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
11 |
+
from torch.optim.lr_scheduler import LambdaLR
|
12 |
+
from ldm.util import instantiate_from_config
|
13 |
+
from icecream import ic
|
14 |
+
|
15 |
+
class VQModel(pl.LightningModule):
|
16 |
+
def __init__(self,
|
17 |
+
ddconfig,
|
18 |
+
lossconfig,
|
19 |
+
n_embed,
|
20 |
+
embed_dim,
|
21 |
+
ckpt_path=None,
|
22 |
+
ignore_keys=[],
|
23 |
+
image_key="image",
|
24 |
+
colorize_nlabels=None,
|
25 |
+
monitor=None,
|
26 |
+
batch_resize_range=None,
|
27 |
+
scheduler_config=None,
|
28 |
+
lr_g_factor=1.0,
|
29 |
+
remap=None,
|
30 |
+
sane_index_shape=False, # tell vector quantizer to return indices as bhw
|
31 |
+
use_ema=False
|
32 |
+
):
|
33 |
+
super().__init__()
|
34 |
+
self.embed_dim = embed_dim
|
35 |
+
self.n_embed = n_embed
|
36 |
+
self.image_key = image_key
|
37 |
+
self.encoder = Encoder(**ddconfig)
|
38 |
+
self.decoder = Decoder(**ddconfig)
|
39 |
+
self.loss = instantiate_from_config(lossconfig)
|
40 |
+
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
|
41 |
+
remap=remap,
|
42 |
+
sane_index_shape=sane_index_shape)
|
43 |
+
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
|
44 |
+
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
45 |
+
if colorize_nlabels is not None:
|
46 |
+
assert type(colorize_nlabels)==int
|
47 |
+
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
48 |
+
if monitor is not None:
|
49 |
+
self.monitor = monitor
|
50 |
+
self.batch_resize_range = batch_resize_range
|
51 |
+
if self.batch_resize_range is not None:
|
52 |
+
print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
|
53 |
+
|
54 |
+
self.use_ema = use_ema
|
55 |
+
if self.use_ema:
|
56 |
+
self.model_ema = LitEma(self)
|
57 |
+
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
58 |
+
|
59 |
+
if ckpt_path is not None:
|
60 |
+
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
61 |
+
self.scheduler_config = scheduler_config
|
62 |
+
self.lr_g_factor = lr_g_factor
|
63 |
+
|
64 |
+
@contextmanager
|
65 |
+
def ema_scope(self, context=None):
|
66 |
+
if self.use_ema:
|
67 |
+
self.model_ema.store(self.parameters())
|
68 |
+
self.model_ema.copy_to(self)
|
69 |
+
if context is not None:
|
70 |
+
print(f"{context}: Switched to EMA weights")
|
71 |
+
try:
|
72 |
+
yield None
|
73 |
+
finally:
|
74 |
+
if self.use_ema:
|
75 |
+
self.model_ema.restore(self.parameters())
|
76 |
+
if context is not None:
|
77 |
+
print(f"{context}: Restored training weights")
|
78 |
+
|
79 |
+
def init_from_ckpt(self, path, ignore_keys=list()):
|
80 |
+
sd = torch.load(path, map_location="cpu")["state_dict"]
|
81 |
+
keys = list(sd.keys())
|
82 |
+
for k in keys:
|
83 |
+
for ik in ignore_keys:
|
84 |
+
if k.startswith(ik):
|
85 |
+
print("Deleting key {} from state_dict.".format(k))
|
86 |
+
del sd[k]
|
87 |
+
missing, unexpected = self.load_state_dict(sd, strict=False)
|
88 |
+
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
|
89 |
+
if len(missing) > 0:
|
90 |
+
print(f"Missing Keys: {missing}")
|
91 |
+
print(f"Unexpected Keys: {unexpected}")
|
92 |
+
|
93 |
+
def on_train_batch_end(self, *args, **kwargs):
|
94 |
+
if self.use_ema:
|
95 |
+
self.model_ema(self)
|
96 |
+
|
97 |
+
def encode(self, x):
|
98 |
+
h = self.encoder(x)
|
99 |
+
h = self.quant_conv(h)
|
100 |
+
quant, emb_loss, info = self.quantize(h)
|
101 |
+
return quant, emb_loss, info
|
102 |
+
|
103 |
+
def encode_to_prequant(self, x):
|
104 |
+
h = self.encoder(x)
|
105 |
+
h = self.quant_conv(h)
|
106 |
+
return h
|
107 |
+
|
108 |
+
def decode(self, quant):
|
109 |
+
quant = self.post_quant_conv(quant)
|
110 |
+
dec = self.decoder(quant)
|
111 |
+
return dec
|
112 |
+
|
113 |
+
def decode_code(self, code_b):
|
114 |
+
quant_b = self.quantize.embed_code(code_b)
|
115 |
+
dec = self.decode(quant_b)
|
116 |
+
return dec
|
117 |
+
|
118 |
+
def forward(self, input, return_pred_indices=False):
|
119 |
+
quant, diff, (_,_,ind) = self.encode(input)
|
120 |
+
dec = self.decode(quant)
|
121 |
+
if return_pred_indices:
|
122 |
+
return dec, diff, ind
|
123 |
+
return dec, diff
|
124 |
+
|
125 |
+
def get_input(self, batch, k):
|
126 |
+
x = batch[k]
|
127 |
+
if len(x.shape) == 3:
|
128 |
+
x = x[..., None]
|
129 |
+
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
|
130 |
+
if self.batch_resize_range is not None:
|
131 |
+
lower_size = self.batch_resize_range[0]
|
132 |
+
upper_size = self.batch_resize_range[1]
|
133 |
+
if self.global_step <= 4:
|
134 |
+
# do the first few batches with max size to avoid later oom
|
135 |
+
new_resize = upper_size
|
136 |
+
else:
|
137 |
+
new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
|
138 |
+
if new_resize != x.shape[2]:
|
139 |
+
x = F.interpolate(x, size=new_resize, mode="bicubic")
|
140 |
+
x = x.detach()
|
141 |
+
return x
|
142 |
+
|
143 |
+
def training_step(self, batch, batch_idx, optimizer_idx):
|
144 |
+
# https://github.com/pytorch/pytorch/issues/37142
|
145 |
+
# try not to fool the heuristics
|
146 |
+
x = self.get_input(batch, self.image_key)
|
147 |
+
xrec, qloss, ind = self(x, return_pred_indices=True)
|
148 |
+
|
149 |
+
if optimizer_idx == 0:
|
150 |
+
# autoencode
|
151 |
+
aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
152 |
+
last_layer=self.get_last_layer(), split="train",
|
153 |
+
predicted_indices=ind)
|
154 |
+
|
155 |
+
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
156 |
+
return aeloss
|
157 |
+
|
158 |
+
if optimizer_idx == 1:
|
159 |
+
# discriminator
|
160 |
+
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
161 |
+
last_layer=self.get_last_layer(), split="train")
|
162 |
+
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
163 |
+
return discloss
|
164 |
+
|
165 |
+
def validation_step(self, batch, batch_idx):
|
166 |
+
log_dict = self._validation_step(batch, batch_idx)
|
167 |
+
with self.ema_scope():
|
168 |
+
log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
|
169 |
+
return log_dict
|
170 |
+
|
171 |
+
def _validation_step(self, batch, batch_idx, suffix=""):
|
172 |
+
x = self.get_input(batch, self.image_key)
|
173 |
+
xrec, qloss, ind = self(x, return_pred_indices=True)
|
174 |
+
aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
|
175 |
+
self.global_step,
|
176 |
+
last_layer=self.get_last_layer(),
|
177 |
+
split="val"+suffix,
|
178 |
+
predicted_indices=ind
|
179 |
+
)
|
180 |
+
|
181 |
+
discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
|
182 |
+
self.global_step,
|
183 |
+
last_layer=self.get_last_layer(),
|
184 |
+
split="val"+suffix,
|
185 |
+
predicted_indices=ind
|
186 |
+
)
|
187 |
+
rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
|
188 |
+
self.log(f"val{suffix}/rec_loss", rec_loss,
|
189 |
+
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
|
190 |
+
self.log(f"val{suffix}/aeloss", aeloss,
|
191 |
+
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
|
192 |
+
if version.parse(pl.__version__) >= version.parse('1.4.0'):
|
193 |
+
del log_dict_ae[f"val{suffix}/rec_loss"]
|
194 |
+
self.log_dict(log_dict_ae)
|
195 |
+
self.log_dict(log_dict_disc)
|
196 |
+
return self.log_dict
|
197 |
+
|
198 |
+
def test_step(self, batch, batch_idx):
|
199 |
+
x = self.get_input(batch, self.image_key)
|
200 |
+
xrec, qloss, ind = self(x, return_pred_indices=True)
|
201 |
+
reconstructions = (xrec + 1)/2 # to mel scale
|
202 |
+
test_ckpt_path = os.path.basename(self.trainer.tested_ckpt_path)
|
203 |
+
savedir = os.path.join(self.trainer.log_dir,f'output_imgs_{test_ckpt_path}','fake_class')
|
204 |
+
if not os.path.exists(savedir):
|
205 |
+
os.makedirs(savedir)
|
206 |
+
|
207 |
+
file_names = batch['f_name']
|
208 |
+
# print(f"reconstructions.shape:{reconstructions.shape}",file_names)
|
209 |
+
reconstructions = reconstructions.cpu().numpy().squeeze(1) # squuze channel dim
|
210 |
+
for b in range(reconstructions.shape[0]):
|
211 |
+
vname_num_split_index = file_names[b].rfind('_')# file_names[b]:video_name+'_'+num
|
212 |
+
v_n,num = file_names[b][:vname_num_split_index],file_names[b][vname_num_split_index+1:]
|
213 |
+
save_img_path = os.path.join(savedir,f'{v_n}_sample_{num}.npy')
|
214 |
+
np.save(save_img_path,reconstructions[b])
|
215 |
+
|
216 |
+
return None
|
217 |
+
|
218 |
+
def configure_optimizers(self):
|
219 |
+
lr_d = self.learning_rate
|
220 |
+
lr_g = self.lr_g_factor*self.learning_rate
|
221 |
+
print("lr_d", lr_d)
|
222 |
+
print("lr_g", lr_g)
|
223 |
+
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
|
224 |
+
list(self.decoder.parameters())+
|
225 |
+
list(self.quantize.parameters())+
|
226 |
+
list(self.quant_conv.parameters())+
|
227 |
+
list(self.post_quant_conv.parameters()),
|
228 |
+
lr=lr_g, betas=(0.5, 0.9))
|
229 |
+
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
|
230 |
+
lr=lr_d, betas=(0.5, 0.9))
|
231 |
+
|
232 |
+
if self.scheduler_config is not None:
|
233 |
+
scheduler = instantiate_from_config(self.scheduler_config)
|
234 |
+
|
235 |
+
print("Setting up LambdaLR scheduler...")
|
236 |
+
scheduler = [
|
237 |
+
{
|
238 |
+
'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
|
239 |
+
'interval': 'step',
|
240 |
+
'frequency': 1
|
241 |
+
},
|
242 |
+
{
|
243 |
+
'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
|
244 |
+
'interval': 'step',
|
245 |
+
'frequency': 1
|
246 |
+
},
|
247 |
+
]
|
248 |
+
return [opt_ae, opt_disc], scheduler
|
249 |
+
return [opt_ae, opt_disc], []
|
250 |
+
|
251 |
+
def get_last_layer(self):
|
252 |
+
return self.decoder.conv_out.weight
|
253 |
+
|
254 |
+
def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
|
255 |
+
log = dict()
|
256 |
+
x = self.get_input(batch, self.image_key)
|
257 |
+
x = x.to(self.device)
|
258 |
+
if only_inputs:
|
259 |
+
log["inputs"] = x
|
260 |
+
return log
|
261 |
+
xrec, _ = self(x)
|
262 |
+
if x.shape[1] > 3:
|
263 |
+
# colorize with random projection
|
264 |
+
assert xrec.shape[1] > 3
|
265 |
+
x = self.to_rgb(x)
|
266 |
+
xrec = self.to_rgb(xrec)
|
267 |
+
log["inputs"] = x
|
268 |
+
log["reconstructions"] = xrec
|
269 |
+
if plot_ema:
|
270 |
+
with self.ema_scope():
|
271 |
+
xrec_ema, _ = self(x)
|
272 |
+
if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
|
273 |
+
log["reconstructions_ema"] = xrec_ema
|
274 |
+
return log
|
275 |
+
|
276 |
+
def to_rgb(self, x):
|
277 |
+
assert self.image_key == "segmentation"
|
278 |
+
if not hasattr(self, "colorize"):
|
279 |
+
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
280 |
+
x = F.conv2d(x, weight=self.colorize)
|
281 |
+
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
|
282 |
+
return x
|
283 |
+
|
284 |
+
|
285 |
+
class VQModelInterface(VQModel):
|
286 |
+
def __init__(self, embed_dim, *args, **kwargs):
|
287 |
+
super().__init__(embed_dim=embed_dim, *args, **kwargs)
|
288 |
+
self.embed_dim = embed_dim
|
289 |
+
|
290 |
+
def encode(self, x):# VQModel的quantize写在encoder里,VQModelInterface则将其写在decoder里
|
291 |
+
h = self.encoder(x)
|
292 |
+
h = self.quant_conv(h)
|
293 |
+
return h
|
294 |
+
|
295 |
+
def decode(self, h, force_not_quantize=False):
|
296 |
+
# also go through quantization layer
|
297 |
+
if not force_not_quantize:
|
298 |
+
quant, emb_loss, info = self.quantize(h)
|
299 |
+
else:
|
300 |
+
quant = h
|
301 |
+
quant = self.post_quant_conv(quant)
|
302 |
+
dec = self.decoder(quant)
|
303 |
+
return dec
|
304 |
+
|
305 |
+
|
306 |
+
class AutoencoderKL(pl.LightningModule):
|
307 |
+
def __init__(self,
|
308 |
+
ddconfig,
|
309 |
+
lossconfig,
|
310 |
+
embed_dim,
|
311 |
+
ckpt_path=None,
|
312 |
+
ignore_keys=[],
|
313 |
+
image_key="image",
|
314 |
+
colorize_nlabels=None,
|
315 |
+
monitor=None,
|
316 |
+
):
|
317 |
+
super().__init__()
|
318 |
+
self.to_1d = False
|
319 |
+
print(f"to_1d is {self.to_1d} in AUTOENCODER")
|
320 |
+
self.image_key = image_key
|
321 |
+
self.encoder = Encoder(**ddconfig)
|
322 |
+
self.decoder = Decoder(**ddconfig)
|
323 |
+
self.loss = instantiate_from_config(lossconfig)
|
324 |
+
assert ddconfig["double_z"]
|
325 |
+
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
|
326 |
+
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
327 |
+
self.embed_dim = embed_dim
|
328 |
+
if colorize_nlabels is not None:
|
329 |
+
assert type(colorize_nlabels)==int
|
330 |
+
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
331 |
+
if monitor is not None:
|
332 |
+
self.monitor = monitor
|
333 |
+
if ckpt_path is not None:
|
334 |
+
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
335 |
+
# self.automatic_optimization = False # hjw for debug
|
336 |
+
|
337 |
+
def init_from_ckpt(self, path, ignore_keys=list()):
|
338 |
+
sd = torch.load(path, map_location="cpu")["state_dict"]
|
339 |
+
keys = list(sd.keys())
|
340 |
+
for k in keys:
|
341 |
+
for ik in ignore_keys:
|
342 |
+
if k.startswith(ik):
|
343 |
+
print("Deleting key {} from state_dict.".format(k))
|
344 |
+
del sd[k]
|
345 |
+
self.load_state_dict(sd, strict=False)
|
346 |
+
print(f"Restored from {path}")
|
347 |
+
|
348 |
+
def encode(self, x):
|
349 |
+
if self.to_1d and len(x.shape)==3:
|
350 |
+
x = x.unsqueeze(1)
|
351 |
+
h = self.encoder(x)
|
352 |
+
moments = self.quant_conv(h)
|
353 |
+
if self.to_1d:
|
354 |
+
b,c,h,w = moments.shape
|
355 |
+
moments = moments.reshape(b,c*h,w)
|
356 |
+
posterior = DiagonalGaussianDistribution(moments)
|
357 |
+
return posterior
|
358 |
+
|
359 |
+
def decode(self, z):
|
360 |
+
if self.to_1d:
|
361 |
+
b,c_h,w = z.shape
|
362 |
+
c = self.post_quant_conv.in_channels
|
363 |
+
z = z.reshape(b,c,-1,w)
|
364 |
+
z = self.post_quant_conv(z)
|
365 |
+
dec = self.decoder(z)
|
366 |
+
return dec
|
367 |
+
|
368 |
+
def forward(self, input, sample_posterior=True):
|
369 |
+
posterior = self.encode(input)
|
370 |
+
if sample_posterior:
|
371 |
+
z = posterior.sample()
|
372 |
+
else:
|
373 |
+
z = posterior.mode()
|
374 |
+
dec = self.decode(z)
|
375 |
+
return dec, posterior
|
376 |
+
|
377 |
+
def get_input(self, batch, k):
|
378 |
+
x = batch[k]
|
379 |
+
if len(x.shape) == 3:
|
380 |
+
x = x[..., None]
|
381 |
+
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
|
382 |
+
return x
|
383 |
+
|
384 |
+
def training_step(self, batch, batch_idx, optimizer_idx):
|
385 |
+
inputs = self.get_input(batch, self.image_key)
|
386 |
+
reconstructions, posterior = self(inputs)
|
387 |
+
|
388 |
+
if optimizer_idx == 0:
|
389 |
+
# train encoder+decoder+logvar
|
390 |
+
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
|
391 |
+
last_layer=self.get_last_layer(), split="train")
|
392 |
+
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
393 |
+
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
|
394 |
+
# print(optimizer_idx,log_dict_ae)
|
395 |
+
return aeloss
|
396 |
+
|
397 |
+
if optimizer_idx == 1:
|
398 |
+
# train the discriminator
|
399 |
+
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
|
400 |
+
last_layer=self.get_last_layer(), split="train")
|
401 |
+
|
402 |
+
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
403 |
+
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
|
404 |
+
# print(optimizer_idx,log_dict_disc)
|
405 |
+
return discloss
|
406 |
+
|
407 |
+
def validation_step(self, batch, batch_idx):
|
408 |
+
inputs = self.get_input(batch, self.image_key)
|
409 |
+
reconstructions, posterior = self(inputs)
|
410 |
+
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
|
411 |
+
last_layer=self.get_last_layer(), split="val")
|
412 |
+
|
413 |
+
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
|
414 |
+
last_layer=self.get_last_layer(), split="val")
|
415 |
+
|
416 |
+
self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
|
417 |
+
self.log_dict(log_dict_ae)
|
418 |
+
self.log_dict(log_dict_disc)
|
419 |
+
return self.log_dict
|
420 |
+
|
421 |
+
def test_step(self, batch, batch_idx):
|
422 |
+
inputs = self.get_input(batch, self.image_key)# inputs shape:(b,mel_len,T)
|
423 |
+
reconstructions, posterior = self(inputs)# reconstructions:(b,mel_len,T)
|
424 |
+
mse_loss = torch.nn.functional.mse_loss(reconstructions,inputs)
|
425 |
+
self.log('test/mse_loss',mse_loss)
|
426 |
+
|
427 |
+
test_ckpt_path = os.path.basename(self.trainer.tested_ckpt_path)
|
428 |
+
savedir = os.path.join(self.trainer.log_dir,f'output_imgs_{test_ckpt_path}','fake_class')
|
429 |
+
if batch_idx == 0:
|
430 |
+
print(f"save_path is: {savedir}")
|
431 |
+
if not os.path.exists(savedir):
|
432 |
+
os.makedirs(savedir)
|
433 |
+
print(f"save_path is: {savedir}")
|
434 |
+
|
435 |
+
file_names = batch['f_name']
|
436 |
+
# print(f"reconstructions.shape:{reconstructions.shape}",file_names)
|
437 |
+
# reconstructions = (reconstructions + 1)/2 # to mel scale
|
438 |
+
reconstructions = reconstructions.cpu().numpy().squeeze(1) # squeeze channel dim
|
439 |
+
for b in range(reconstructions.shape[0]):
|
440 |
+
vname_num_split_index = file_names[b].rfind('_')# file_names[b]:video_name+'_'+num
|
441 |
+
v_n,num = file_names[b][:vname_num_split_index],file_names[b][vname_num_split_index+1:]
|
442 |
+
save_img_path = os.path.join(savedir, f'{v_n}.npy') # f'{v_n}_sample_{num}.npy' f'{v_n}.npy'
|
443 |
+
np.save(save_img_path,reconstructions[b])
|
444 |
+
|
445 |
+
return None
|
446 |
+
|
447 |
+
def configure_optimizers(self):
|
448 |
+
lr = self.learning_rate
|
449 |
+
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
|
450 |
+
list(self.decoder.parameters())+
|
451 |
+
list(self.quant_conv.parameters())+
|
452 |
+
list(self.post_quant_conv.parameters()),
|
453 |
+
lr=lr, betas=(0.5, 0.9))
|
454 |
+
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
|
455 |
+
lr=lr, betas=(0.5, 0.9))
|
456 |
+
return [opt_ae, opt_disc], []
|
457 |
+
|
458 |
+
def get_last_layer(self):
|
459 |
+
return self.decoder.conv_out.weight
|
460 |
+
|
461 |
+
@torch.no_grad()
|
462 |
+
def log_images(self, batch, only_inputs=False,save_dir = 'mel_result_ae13_26_debug/fake_class', **kwargs): # 在main.py的on_validation_batch_end中调用
|
463 |
+
log = dict()
|
464 |
+
x = self.get_input(batch, self.image_key)
|
465 |
+
x = x.to(self.device)
|
466 |
+
if not only_inputs:
|
467 |
+
xrec, posterior = self(x)
|
468 |
+
if x.shape[1] > 3:
|
469 |
+
# colorize with random projection
|
470 |
+
assert xrec.shape[1] > 3
|
471 |
+
x = self.to_rgb(x)
|
472 |
+
xrec = self.to_rgb(xrec)
|
473 |
+
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
|
474 |
+
log["reconstructions"] = xrec
|
475 |
+
log["inputs"] = x
|
476 |
+
return log
|
477 |
+
|
478 |
+
def to_rgb(self, x):
|
479 |
+
assert self.image_key == "segmentation"
|
480 |
+
if not hasattr(self, "colorize"):
|
481 |
+
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
482 |
+
x = F.conv2d(x, weight=self.colorize)
|
483 |
+
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
|
484 |
+
return x
|
485 |
+
|
486 |
+
|
487 |
+
class IdentityFirstStage(torch.nn.Module):
|
488 |
+
def __init__(self, *args, vq_interface=False, **kwargs):
|
489 |
+
self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
|
490 |
+
super().__init__()
|
491 |
+
|
492 |
+
def encode(self, x, *args, **kwargs):
|
493 |
+
return x
|
494 |
+
|
495 |
+
def decode(self, x, *args, **kwargs):
|
496 |
+
return x
|
497 |
+
|
498 |
+
def quantize(self, x, *args, **kwargs):
|
499 |
+
if self.vq_interface:
|
500 |
+
return x, None, [None, None, None]
|
501 |
+
return x
|
502 |
+
|
503 |
+
def forward(self, x, *args, **kwargs):
|
504 |
+
return x
|