mrfakename committed on
Commit
b4fc33b
·
verified ·
1 Parent(s): 75e5a1a

Sync from GitHub repo

Browse files

This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there

src/f5_tts/eval/eval_infer_batch.py CHANGED
@@ -187,7 +187,7 @@ def main():
187
  # Final result
188
  for i, gen in enumerate(generated):
189
  gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
190
- gen_mel_spec = gen.permute(0, 2, 1)
191
  if mel_spec_type == "vocos":
192
  generated_wave = vocoder.decode(gen_mel_spec)
193
  elif mel_spec_type == "bigvgan":
@@ -195,7 +195,7 @@ def main():
195
 
196
  if ref_rms_list[i] < target_rms:
197
  generated_wave = generated_wave * ref_rms_list[i] / target_rms
198
- torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave.squeeze(0).cpu(), target_sample_rate)
199
 
200
  accelerator.wait_for_everyone()
201
  if accelerator.is_main_process:
 
187
  # Final result
188
  for i, gen in enumerate(generated):
189
  gen = gen[ref_mel_lens[i] : total_mel_lens[i], :].unsqueeze(0)
190
+ gen_mel_spec = gen.permute(0, 2, 1).to(torch.float32)
191
  if mel_spec_type == "vocos":
192
  generated_wave = vocoder.decode(gen_mel_spec)
193
  elif mel_spec_type == "bigvgan":
 
195
 
196
  if ref_rms_list[i] < target_rms:
197
  generated_wave = generated_wave * ref_rms_list[i] / target_rms
198
+ torchaudio.save(f"{output_dir}/{utts[i]}.wav", generated_wave.cpu(), target_sample_rate)
199
 
200
  accelerator.wait_for_everyone()
201
  if accelerator.is_main_process:
src/f5_tts/infer/infer_cli.py CHANGED
@@ -109,13 +109,16 @@ ckpt_file = args.ckpt_file if args.ckpt_file else ""
109
  vocab_file = args.vocab_file if args.vocab_file else ""
110
  remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
111
  speed = args.speed
 
112
  wave_path = Path(output_dir) / "infer_cli_out.wav"
113
  # spectrogram_path = Path(output_dir) / "infer_cli_out.png"
114
- if args.vocoder_name == "vocos":
 
 
 
115
  vocoder_local_path = "../checkpoints/vocos-mel-24khz"
116
- elif args.vocoder_name == "bigvgan":
117
  vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
118
- mel_spec_type = args.vocoder_name
119
 
120
  vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=args.load_vocoder_from_local, local_path=vocoder_local_path)
121
 
@@ -125,19 +128,20 @@ if model == "F5-TTS":
125
  model_cls = DiT
126
  model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
127
  if ckpt_file == "":
128
- if args.vocoder_name == "vocos":
129
  repo_name = "F5-TTS"
130
  exp_name = "F5TTS_Base"
131
  ckpt_step = 1200000
132
  ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
133
  # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
134
- elif args.vocoder_name == "bigvgan":
135
  repo_name = "F5-TTS"
136
  exp_name = "F5TTS_Base_bigvgan"
137
  ckpt_step = 1250000
138
  ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))
139
 
140
  elif model == "E2-TTS":
 
141
  model_cls = UNetT
142
  model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
143
  if ckpt_file == "":
@@ -146,15 +150,10 @@ elif model == "E2-TTS":
146
  ckpt_step = 1200000
147
  ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
148
  # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
149
- elif args.vocoder_name == "bigvgan": # TODO: need to test
150
- repo_name = "F5-TTS"
151
- exp_name = "F5TTS_Base_bigvgan"
152
- ckpt_step = 1250000
153
- ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))
154
 
155
 
156
  print(f"Using {model}...")
157
- ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=args.vocoder_name, vocab_file=vocab_file)
158
 
159
 
160
  def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove_silence, speed):
 
109
  vocab_file = args.vocab_file if args.vocab_file else ""
110
  remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
111
  speed = args.speed
112
+
113
  wave_path = Path(output_dir) / "infer_cli_out.wav"
114
  # spectrogram_path = Path(output_dir) / "infer_cli_out.png"
115
+
116
+ vocoder_name = args.vocoder_name
117
+ mel_spec_type = args.vocoder_name
118
+ if vocoder_name == "vocos":
119
  vocoder_local_path = "../checkpoints/vocos-mel-24khz"
120
+ elif vocoder_name == "bigvgan":
121
  vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
 
122
 
123
  vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=args.load_vocoder_from_local, local_path=vocoder_local_path)
124
 
 
128
  model_cls = DiT
129
  model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
130
  if ckpt_file == "":
131
+ if vocoder_name == "vocos":
132
  repo_name = "F5-TTS"
133
  exp_name = "F5TTS_Base"
134
  ckpt_step = 1200000
135
  ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
136
  # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
137
+ elif vocoder_name == "bigvgan":
138
  repo_name = "F5-TTS"
139
  exp_name = "F5TTS_Base_bigvgan"
140
  ckpt_step = 1250000
141
  ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))
142
 
143
  elif model == "E2-TTS":
144
+ assert vocoder_name == "vocos", "E2-TTS only supports vocoder vocos"
145
  model_cls = UNetT
146
  model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
147
  if ckpt_file == "":
 
150
  ckpt_step = 1200000
151
  ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
152
  # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
 
 
 
 
 
153
 
154
 
155
  print(f"Using {model}...")
156
+ ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=mel_spec_type, vocab_file=vocab_file)
157
 
158
 
159
  def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove_silence, speed):
src/f5_tts/infer/speech_edit.py CHANGED
@@ -187,5 +187,5 @@ with torch.inference_mode():
187
  generated_wave = generated_wave * rms / target_rms
188
 
189
  save_spectrogram(gen_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png")
190
- torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave.squeeze(0).cpu(), target_sample_rate)
191
  print(f"Generated wav: {generated_wave.shape}")
 
187
  generated_wave = generated_wave * rms / target_rms
188
 
189
  save_spectrogram(gen_mel_spec[0].cpu().numpy(), f"{output_dir}/speech_edit_out.png")
190
+ torchaudio.save(f"{output_dir}/speech_edit_out.wav", generated_wave.cpu(), target_sample_rate)
191
  print(f"Generated wav: {generated_wave.shape}")