rvc-inference / app_multi.py
DJQmUKV
fix: another fix attempt, what the f**k is happening
f98b802
raw
history blame
11.6 kB
from typing import Union
from argparse import ArgumentParser
import asyncio
import json
from os import path
import gradio as gr
import torch
import numpy as np
import librosa
import edge_tts
from config import device
import util
from infer_pack.models import (
SynthesizerTrnMs256NSFsid,
SynthesizerTrnMs256NSFsid_nono
)
from vc_infer_pipeline import VC
# Argument parsing: paths to the hubert model and multi-model config, plus
# Gradio server options consumed by app.queue(...).launch(...) at the end.
arg_parser = ArgumentParser()
arg_parser.add_argument(
    '--hubert',
    default='hubert_base.pt',
    help='path to hubert base model (default: hubert_base.pt)'
)
arg_parser.add_argument(
    '--config',
    default='multi_config.json',
    help='path to config file (default: multi_config.json)'
)
arg_parser.add_argument(
    '--bind',
    default='127.0.0.1',
    help='gradio server listen address (default: 127.0.0.1)'
)
arg_parser.add_argument(
    '--port',
    # Without type=int a port given on the command line would be passed to
    # launch(server_port=...) as a str; only the default was an int.
    type=int,
    default=7860,
    help='gradio server listen port (default: 7860)'
)
arg_parser.add_argument(
    '--share',
    action='store_true',
    help='let gradio create a public link for you'
)
arg_parser.add_argument(
    '--api',
    action='store_true',
    help='enable api endpoint'
)
arg_parser.add_argument(
    '--cache-examples',
    action='store_true',
    help='enable example caching, please remember delete gradio_cached_examples folder when example config has been modified'  # noqa
)
args = arg_parser.parse_args()
# Extra CSS for the Markdown panel created with elem_id='model_info':
# shrink and right-float the model icon, and drop paragraph margins.
app_css = '''
#model_info img {
max-width: 100px;
max-height: 100px;
float: right;
}
#model_info p {
margin: unset;
}
'''
# Root Blocks container for the whole UI (analytics disabled).
app = gr.Blocks(
    theme=gr.themes.Glass(),
    analytics_enabled=False,
    css=app_css,
)
# Load hubert model (shared by every RVC model's inference pipeline).
hubert_model = util.load_hubert_model(device, args.hubert)
hubert_model.eval()


def _read_json(json_path):
    """Parse the JSON file at ``json_path`` and close the handle afterwards.

    Decodes as UTF-8 (tolerating a leading BOM) instead of the platform's
    locale encoding, so non-ASCII model metadata loads everywhere.
    """
    with open(json_path, 'r', encoding='utf-8-sig') as json_file:
        return json.load(json_file)


def _load_rvc_model(model_name):
    """Load the RVC model stored under ``model/<model_name>/``.

    ``config.json`` in that folder supplies the metadata; its ``model`` entry
    names the checkpoint file.

    Returns:
        dict with keys ``name``, ``metadata``, ``vc``, ``net_g``, ``if_f0`` and
        ``target_sr``, the entry shape read by ``vc_func`` and
        ``update_model_info``.
    """
    model_dir = path.join('model', model_name)
    model_info = _read_json(path.join(model_dir, 'config.json'))
    # NOTE(review): torch.load unpickles the checkpoint, which can execute
    # arbitrary code; only place checkpoints from trusted sources here.
    cpt = torch.load(
        path.join(model_dir, model_info['model']),
        map_location='cpu'
    )
    tgt_sr = cpt['config'][-1]
    cpt['config'][-3] = cpt['weight']['emb_g.weight'].shape[0]  # n_spk
    if_f0 = cpt.get('f0', 1)
    half = util.is_half(device)
    # Checkpoints whose 'f0' flag is 1 use the NSF synthesizer; others use
    # the `_nono` variant (no is_half argument).
    net_g: Union[SynthesizerTrnMs256NSFsid, SynthesizerTrnMs256NSFsid_nono]
    if if_f0 == 1:
        net_g = SynthesizerTrnMs256NSFsid(*cpt['config'], is_half=half)
    else:
        net_g = SynthesizerTrnMs256NSFsid_nono(*cpt['config'])
    del net_g.enc_q
    # According to original code, this thing seems necessary.
    print(net_g.load_state_dict(cpt['weight'], strict=False))
    net_g.eval().to(device)
    net_g = net_g.half() if half else net_g.float()
    return dict(
        name=model_name,
        metadata=model_info,
        vc=VC(tgt_sr, device, half),
        net_g=net_g,
        if_f0=if_f0,
        target_sr=tgt_sr
    )


# Load models listed in the multi-model config; a missing or null "models"
# entry means no models instead of a TypeError.
multi_cfg = _read_json(args.config)
loaded_models = []
for model_name in multi_cfg.get('models') or []:
    print(f'Loading model: {model_name}')
    loaded_models.append(_load_rvc_model(model_name))
print(f'Models loaded: {len(loaded_models)}')
# Edge TTS speakers (fetched once at startup; indexed by the TTS dropdown).
tts_speakers_list = asyncio.run(edge_tts.list_voices())
# https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI/blob/main/infer-web.py#L118 # noqa
def vc_func(input_audio, model_index, pitch_adjust, f0_method, feat_ratio):
    """Convert an audio clip with one of the loaded RVC models.

    Args:
        input_audio: ``(sample_rate, samples)`` tuple (the ``gr.Audio``
            value), or None.
        model_index: Index into ``loaded_models``, or None.
        pitch_adjust: Pitch shift for the pipeline; truncated with ``int()``.
        f0_method: f0 extraction method name forwarded to the pipeline.
        feat_ratio: Feature ratio forwarded to the pipeline.

    Returns:
        ``(output, message)`` where ``output`` is ``(target_sr, audio)`` on
        success, or None when an input is missing.
    """
    if input_audio is None:
        return (None, 'Please provide input audio.')
    if model_index is None:
        return (None, 'Please select a model.')
    model = loaded_models[model_index]
    # Reference: so-vits
    (audio_samp, audio_npy) = input_audio
    # Normalise to float32. Integer PCM is scaled by its dtype's max value
    # (https://stackoverflow.com/questions/26921836/). Other float dtypes are
    # only cast: np.iinfo() raises ValueError for non-integer dtypes.
    if np.issubdtype(audio_npy.dtype, np.integer):
        audio_npy = (
            audio_npy / np.iinfo(audio_npy.dtype).max
        ).astype(np.float32)
    elif audio_npy.dtype != np.float32:
        audio_npy = audio_npy.astype(np.float32)
    # Multi-channel input: transpose to channel-first for librosa.to_mono.
    if len(audio_npy.shape) > 1:
        audio_npy = librosa.to_mono(audio_npy.transpose(1, 0))
    # The pipeline is always fed 16 kHz audio.
    if audio_samp != 16000:
        audio_npy = librosa.resample(
            audio_npy,
            orig_sr=audio_samp,
            target_sr=16000
        )
    pitch_int = int(pitch_adjust)
    # Per-stage timings (npy, f0, infer) passed to the pipeline, printed below.
    times = [0, 0, 0]
    output_audio = model['vc'].pipeline(
        hubert_model,
        model['net_g'],
        model['metadata'].get('speaker_id', 0),
        audio_npy,
        times,
        pitch_int,
        f0_method,
        path.join('model', model['name'], model['metadata']['feat_index']),
        path.join('model', model['name'], model['metadata']['feat_npy']),
        feat_ratio,
        model['if_f0']
    )
    print(f'npy: {times[0]}s, f0: {times[1]}s, infer: {times[2]}s')
    return ((model['target_sr'], output_audio), 'Success')
async def edge_tts_vc_func(
    input_text, model_index, tts_speaker, pitch_adjust, f0_method, feat_ratio
):
    """Speak ``input_text`` with Edge TTS, then convert it with ``vc_func``.

    Args:
        input_text: Text to synthesize; None, empty or whitespace-only text
            is rejected.
        model_index: Index into ``loaded_models``, or None.
        tts_speaker: Index into ``tts_speakers_list``, or None.
        pitch_adjust, f0_method, feat_ratio: Forwarded to ``vc_func``.

    Returns:
        The ``(output, message)`` tuple from ``vc_func``, or
        ``(None, message)`` when an input is missing.
    """
    # Reject blank text as well as None: an empty text box can submit ''.
    if not input_text or not input_text.strip():
        return (None, 'Please provide TTS text.')
    if tts_speaker is None:
        return (None, 'Please select TTS speaker.')
    if model_index is None:
        return (None, 'Please select a model.')
    speaker = tts_speakers_list[tts_speaker]['ShortName']
    (tts_np, tts_sr) = await util.call_edge_tts(speaker, input_text)
    # vc_func runs the model synchronously. Calling it directly here would
    # block the event loop for the whole conversion, so run it on the loop's
    # default executor instead.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(
        None,
        vc_func,
        (tts_sr, tts_np),
        model_index,
        pitch_adjust,
        f0_method,
        feat_ratio
    )
def update_model_info(model_index):
    """Return the Markdown shown in the model info panel.

    Args:
        model_index: Index into ``loaded_models``, or None when nothing is
            selected.

    Returns:
        Markdown with the model's icon (if any), name, author, source and
        note, or a placeholder asking the user to pick a model.
    """
    if model_index is None:
        return (
            '### Model info\n'
            'Please select a model from dropdown above.'
        )
    model = loaded_models[model_index]
    metadata = model['metadata']
    # `or ''` also covers an explicit null "icon" in config.json.
    model_icon = metadata.get('icon') or ''
    if not model_icon:
        # No icon configured: emit no image rather than a broken link to
        # the model folder itself.
        icon_markdown = ''
    elif model_icon.startswith(('http://', 'https://')):
        icon_markdown = '![model icon](%s)' % model_icon
    else:
        # Relative icons are served from the model's folder.
        icon_markdown = '![model icon](/file/model/%s/%s)' % (
            model['name'], model_icon
        )
    return (
        '### Model info\n'
        '{icon}'
        '**{name}**\n\n'
        'Author: {author}\n\n'
        'Source: {source}\n\n'
        '{note}'
    ).format(
        icon=icon_markdown,
        name=metadata.get('name'),
        author=metadata.get('author', 'Anonymous'),
        source=metadata.get('source', 'Unknown'),
        note=metadata.get('note', '')
    )
def _example_vc(input_audio, model_index, pitch_adjust, f0_method, feat_ratio):
    """Run an audio-conversion example and refresh the model info panel.

    Returns a 3-tuple ``(audio, message, model_info_markdown)``, one value
    per output of the audio-conversion ``gr.Examples``.
    """
    converted_audio, status_text = vc_func(
        input_audio,
        model_index,
        pitch_adjust,
        f0_method,
        feat_ratio,
    )
    panel_markdown = update_model_info(model_index)
    return converted_audio, status_text, panel_markdown
async def _example_edge_tts(
    input_text, model_index, tts_speaker, pitch_adjust, f0_method, feat_ratio
):
    """Run a TTS-conversion example and refresh the model info panel.

    Returns a 3-tuple ``(audio, message, model_info_markdown)``, one value
    per output of the TTS-conversion ``gr.Examples``.
    """
    tts_audio, tts_status = await edge_tts_vc_func(
        input_text,
        model_index,
        tts_speaker,
        pitch_adjust,
        f0_method,
        feat_ratio,
    )
    panel_markdown = update_model_info(model_index)
    return tts_audio, tts_status, panel_markdown
# Build the UI inside the Blocks context and wire up the event handlers.
# Component creation order inside the `with` blocks determines the layout.
with app:
    gr.Markdown(
        '## Simple, Stupid RVC Inference WebUI\n'
        'Another RVC inference WebUI based on [RVC-WebUI](https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI), ' # noqa
        'some code and features inspired from so-vits and [zomehwh/rvc-models](https://huggingface.co/spaces/zomehwh/rvc-models).\n' # noqa
    )
    with gr.Row():
        # Left column: the two input tabs, then the conversion settings.
        with gr.Column():
            with gr.Tab('Audio conversion'):
                input_audio = gr.Audio(label='Input audio')
                vc_convert_btn = gr.Button('Convert', variant='primary')
            with gr.Tab('TTS conversion'):
                tts_input = gr.TextArea(
                    label='TTS input text'
                )
                # One entry per voice in tts_speakers_list; with
                # type='index' the handler receives the list position.
                tts_speaker = gr.Dropdown(
                    [
                        '%s (%s)' % (
                            s['FriendlyName'],
                            s['Gender']
                        )
                        for s in tts_speakers_list
                    ],
                    label='TTS speaker',
                    type='index'
                )
                tts_convert_btn = gr.Button('Convert', variant='primary')
            # The settings below are inputs to both the audio and the TTS
            # conversion handlers.
            pitch_adjust = gr.Slider(
                label='Pitch',
                minimum=-24,
                maximum=24,
                step=1,
                value=0
            )
            f0_method = gr.Radio(
                label='f0 methods',
                choices=['pm', 'harvest'],
                value='pm',
                interactive=True
            )
            feat_ratio = gr.Slider(
                label='Feature ratio',
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.6
            )
        # Right column: model picker, model info panel and outputs.
        with gr.Column():
            # Model select; type='index' passes the position in
            # loaded_models to the handlers.
            model_index = gr.Dropdown(
                [
                    '%s - %s' % (
                        m['metadata'].get('source', 'Unknown'),
                        m['metadata'].get('name')
                    )
                    for m in loaded_models
                ],
                label='Model',
                type='index'
            )
            # Model info panel; elem_id='model_info' is what app_css targets.
            with gr.Box():
                model_info = gr.Markdown(
                    '### Model info\n'
                    'Please select a model from dropdown above.',
                    elem_id='model_info'
                )
            output_audio = gr.Audio(label='Output audio')
            output_msg = gr.Textbox(label='Output message')
    # Optional examples from the config's "examples" section ("vc" and
    # "tts_vc" lists). Both caching and run-on-click follow --cache-examples.
    multi_examples = multi_cfg.get('examples')
    if multi_examples:
        with gr.Accordion('Sweet sweet examples', open=False):
            with gr.Row():
                # Audio conversion examples; the three outputs match the
                # 3-tuple returned by _example_vc.
                if multi_examples.get('vc'):
                    gr.Examples(
                        label='Audio conversion examples',
                        examples=multi_examples.get('vc'),
                        inputs=[
                            input_audio, model_index, pitch_adjust, f0_method,
                            feat_ratio
                        ],
                        outputs=[output_audio, output_msg, model_info],
                        fn=_example_vc,
                        cache_examples=args.cache_examples,
                        run_on_click=args.cache_examples
                    )
                # Edge TTS examples; outputs match _example_edge_tts.
                if multi_examples.get('tts_vc'):
                    gr.Examples(
                        label='TTS conversion examples',
                        examples=multi_examples.get('tts_vc'),
                        inputs=[
                            tts_input, model_index, tts_speaker, pitch_adjust,
                            f0_method, feat_ratio
                        ],
                        outputs=[output_audio, output_msg, model_info],
                        fn=_example_edge_tts,
                        cache_examples=args.cache_examples,
                        run_on_click=args.cache_examples
                    )
    # Button handlers; api_name names the corresponding API endpoints.
    vc_convert_btn.click(
        vc_func,
        [input_audio, model_index, pitch_adjust, f0_method, feat_ratio],
        [output_audio, output_msg],
        api_name='audio_conversion'
    )
    tts_convert_btn.click(
        edge_tts_vc_func,
        [
            tts_input, model_index, tts_speaker, pitch_adjust, f0_method,
            feat_ratio
        ],
        [output_audio, output_msg],
        api_name='tts_conversion'
    )
    # Refresh the info panel whenever the model selection changes; this
    # event is registered with queue=False and no progress indicator.
    model_index.change(
        update_model_info,
        inputs=[model_index],
        outputs=[model_info],
        show_progress=False,
        queue=False
    )
# Enable the request queue (one concurrent worker, at most 20 waiting
# requests, api_open following --api), then start the server with the
# address/port/share options from the command line.
queued_app = app.queue(
    concurrency_count=1,
    max_size=20,
    api_open=args.api,
)
queued_app.launch(
    server_name=args.bind,
    server_port=args.port,
    share=args.share,
)