Spaces:
Running
on
Zero
Running
on
Zero
mrfakename
committed on
Sync from GitHub repo
Browse files. This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there.
src/f5_tts/eval/eval_infer_batch.py
CHANGED
@@ -187,7 +187,7 @@ def main():
|
|
187 |
# Final result
|
188 |
for i, gen in enumerate(generated):
|
189 |
gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
|
190 |
-
gen_mel_spec = gen.permute(0, 2, 1)
|
191 |
if mel_spec_type == "vocos":
|
192 |
generated_wave = vocoder.decode(gen_mel_spec)
|
193 |
elif mel_spec_type == "bigvgan":
|
@@ -195,7 +195,7 @@ def main():
|
|
195 |
|
196 |
if ref_rms_list[i] < target_rms:
|
197 |
generated_wave = generated_wave * ref_rms_list[i] / target_rms
|
198 |
-
torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave.
|
199 |
|
200 |
accelerator.wait_for_everyone()
|
201 |
if accelerator.is_main_process:
|
|
|
187 |
# Final result
|
188 |
for i, gen in enumerate(generated):
|
189 |
gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
|
190 |
+
gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
|
191 |
if mel_spec_type == "vocos":
|
192 |
generated_wave = vocoder.decode(gen_mel_spec)
|
193 |
elif mel_spec_type == "bigvgan":
|
|
|
195 |
|
196 |
if ref_rms_list[i] < target_rms:
|
197 |
generated_wave = generated_wave * ref_rms_list[i] / target_rms
|
198 |
+
torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave.cpu(), target_sample_rate)
|
199 |
|
200 |
accelerator.wait_for_everyone()
|
201 |
if accelerator.is_main_process:
|
src/f5_tts/infer/infer_cli.py
CHANGED
@@ -109,13 +109,16 @@ ckpt_file = args.ckpt_file if args.ckpt_file else ""
|
|
109 |
vocab_file = args.vocab_file if args.vocab_file else ""
|
110 |
remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
|
111 |
speed = args.speed
|
|
|
112 |
wave_path = Path(output_dir) / "infer_cli_out.wav"
|
113 |
# spectrogram_path = Path(output_dir) / "infer_cli_out.png"
|
114 |
-
|
|
|
|
|
|
|
115 |
vocoder_local_path = "../checkpoints/vocos-mel-24khz"
|
116 |
-
elif
|
117 |
vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
|
118 |
-
mel_spec_type = args.vocoder_name
|
119 |
|
120 |
vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=args.load_vocoder_from_local, local_path=vocoder_local_path)
|
121 |
|
@@ -125,19 +128,20 @@ if model == "F5-TTS":
|
|
125 |
model_cls = DiT
|
126 |
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
127 |
if ckpt_file == "":
|
128 |
-
if
|
129 |
repo_name = "F5-TTS"
|
130 |
exp_name = "F5TTS_Base"
|
131 |
ckpt_step = 1200000
|
132 |
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
|
133 |
# ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
|
134 |
-
elif
|
135 |
repo_name = "F5-TTS"
|
136 |
exp_name = "F5TTS_Base_bigvgan"
|
137 |
ckpt_step = 1250000
|
138 |
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))
|
139 |
|
140 |
elif model == "E2-TTS":
|
|
|
141 |
model_cls = UNetT
|
142 |
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
143 |
if ckpt_file == "":
|
@@ -146,15 +150,10 @@ elif model == "E2-TTS":
|
|
146 |
ckpt_step = 1200000
|
147 |
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
|
148 |
# ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
|
149 |
-
elif args.vocoder_name == "bigvgan": # TODO: need to test
|
150 |
-
repo_name = "F5-TTS"
|
151 |
-
exp_name = "F5TTS_Base_bigvgan"
|
152 |
-
ckpt_step = 1250000
|
153 |
-
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))
|
154 |
|
155 |
|
156 |
print(f"Using {model}...")
|
157 |
-
ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=
|
158 |
|
159 |
|
160 |
def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove_silence, speed):
|
|
|
109 |
vocab_file = args.vocab_file if args.vocab_file else ""
|
110 |
remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
|
111 |
speed = args.speed
|
112 |
+
|
113 |
wave_path = Path(output_dir) / "infer_cli_out.wav"
|
114 |
# spectrogram_path = Path(output_dir) / "infer_cli_out.png"
|
115 |
+
|
116 |
+
vocoder_name = args.vocoder_name
|
117 |
+
mel_spec_type = args.vocoder_name
|
118 |
+
if vocoder_name == "vocos":
|
119 |
vocoder_local_path = "../checkpoints/vocos-mel-24khz"
|
120 |
+
elif vocoder_name == "bigvgan":
|
121 |
vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
|
|
|
122 |
|
123 |
vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=args.load_vocoder_from_local, local_path=vocoder_local_path)
|
124 |
|
|
|
128 |
model_cls = DiT
|
129 |
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
130 |
if ckpt_file == "":
|
131 |
+
if vocoder_name == "vocos":
|
132 |
repo_name = "F5-TTS"
|
133 |
exp_name = "F5TTS_Base"
|
134 |
ckpt_step = 1200000
|
135 |
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
|
136 |
# ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
|
137 |
+
elif vocoder_name == "bigvgan":
|
138 |
repo_name = "F5-TTS"
|
139 |
exp_name = "F5TTS_Base_bigvgan"
|
140 |
ckpt_step = 1250000
|
141 |
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))
|
142 |
|
143 |
elif model == "E2-TTS":
|
144 |
+
assert vocoder_name == "vocos", "E2-TTS only supports vocoder vocos"
|
145 |
model_cls = UNetT
|
146 |
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
147 |
if ckpt_file == "":
|
|
|
150 |
ckpt_step = 1200000
|
151 |
ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
|
152 |
# ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
|
|
|
|
|
|
|
|
|
|
|
153 |
|
154 |
|
155 |
print(f"Using {model}...")
|
156 |
+
ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=mel_spec_type, vocab_file=vocab_file)
|
157 |
|
158 |
|
159 |
def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove_silence, speed):
|
src/f5_tts/infer/speech_edit.py
CHANGED
@@ -187,5 +187,5 @@ with torch.inference_mode():
|
|
187 |
generated_wave = generated_wave * rms / target_rms
|
188 |
|
189 |
save_spectrogram(gen_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png")
|
190 |
-
torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave.
|
191 |
print(f"Generated wav: {generated_wave.shape}")
|
|
|
187 |
generated_wave = generated_wave * rms / target_rms
|
188 |
|
189 |
save_spectrogram(gen_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png")
|
190 |
+
torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave.cpu(), target_sample_rate)
|
191 |
print(f"Generated wav: {generated_wave.shape}")
|