"""Gradio demo for target-speaker extraction.

Mix two utterances at a chosen SNR, then recover the target speaker's voice
from the mixture using a short reference utterance of that speaker.
"""
import gradio as gr
import torch
import torchaudio
from torch.amp import autocast

from network.models import FilterBandTFGridnet, ResemblyzerVoiceEncoder

# device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = 'cpu'
model = FilterBandTFGridnet(n_layers=5, conditional_dim=256 * 2)
emb = ResemblyzerVoiceEncoder(device=device)

# Holds the most recent mixture produced by the mixing tool.
mixed_voice_tool = None


def load_voice(voice_path):
    """Load an audio file and resample it to 16 kHz."""
    voice, rate = torchaudio.load(voice_path)
    if rate != 16000:
        voice = torchaudio.functional.resample(voice, rate, 16000)
        rate = 16000
    return voice.float(), rate


def mix(voice1_path, voice2_path, snr=0):
    """Mix two utterances at the given SNR (in dB) and cache the result."""
    global mixed_voice_tool
    voice1, _ = load_voice(voice1_path)
    voice2, _ = load_voice(voice2_path)
    # add_noise requires both waveforms to have the same length, so trim to the shorter one.
    length = min(voice1.shape[-1], voice2.shape[-1])
    voice1, voice2 = voice1[..., :length], voice2[..., :length]
    mixture = torchaudio.functional.add_noise(voice1, voice2, torch.tensor([float(snr)])).float()
    mixed_voice_tool = mixture
    return (16000, mixture[0].numpy())


def separate(mixed_voice_path, clean_voice_path, drop_down):
    """Extract the target speaker from the mixture using the reference voice."""
    if drop_down == 'From mixing tool':
        mixed_voice = mixed_voice_tool
        rate = 16000
    else:
        mixed_voice, rate = load_voice(mixed_voice_path)
    if mixed_voice is None or model is None:
        return None
    clean_voice, _ = load_voice(clean_voice_path)
    # Repeat the reference until it covers at least 4 s at 16 kHz, then crop to exactly 4 s.
    if clean_voice.shape[-1] < 16000 * 4:
        n = 16000 * 4 // clean_voice.shape[-1] + 1
        clean_voice = torch.cat([clean_voice] * n, dim=-1)
    clean_voice = clean_voice[:, :16000 * 4]

    model.to(device)
    model.eval()

    # Speaker embeddings of the reference and of the mixture,
    # concatenated to match conditional_dim = 256 * 2.
    e = emb(clean_voice)
    e_mix = emb(mixed_voice)
    e = torch.cat([e, e_mix], dim=1)

    # The separator runs at 8 kHz.
    mixed_voice = torchaudio.functional.resample(mixed_voice, rate, 8000)
    # Mixed precision only when a GPU is used; on CPU this is a no-op.
    with autocast(device_type=device, enabled=device == 'cuda'):
        with torch.no_grad():
            yHat = model(mixed_voice, e)
    yHat = torchaudio.functional.resample(yHat, 8000, 16000).cpu().numpy().astype('float32')
    return (16000, yHat[0])


def load_checkpoint(filepath):
    checkpoint = torch.load(
        filepath,
        weights_only=True,
        map_location=device,
    )
    model.load_state_dict(checkpoint)


# Load the pretrained weights once at startup.
load_checkpoint('checkpoints/concat_emb.pth')

with gr.Blocks() as demo:
    # Mixing tool: combine two speakers at the chosen SNR.
    with gr.Row():
        snr = gr.Slider(label='SNR (dB)', minimum=-10, maximum=10, step=1, value=0)
    with gr.Row():
        with gr.Column(scale=1, min_width=200):
            voice1 = gr.Audio(label='Speaker 1', type='filepath')
        with gr.Column(scale=1, min_width=200):
            voice2 = gr.Audio(label='Speaker 2', type='filepath')
        with gr.Column(scale=1, min_width=200):
            with gr.Row():
                mixed_voice = gr.Audio(label='Mixed voice')
            with gr.Row():
                mix_btn = gr.Button("Mix voices", size='sm')
                mix_btn.click(mix, inputs=[voice1, voice2, snr], outputs=mixed_voice)

    # Separation tool: extract the target speaker from the mixture.
    with gr.Row():
        gr.Label('Extract target speaker voice from mixed voice')
    with gr.Row():
        drop_down = gr.Dropdown(['From mixing tool', 'Upload'], label='Choose mixed voice source')
    with gr.Row():
        with gr.Column(scale=1, min_width=200):
            with gr.Row():
                mixed_voice_path = gr.Audio(label='Mixed voice', type='filepath')
        with gr.Column(scale=1, min_width=200):
            with gr.Row():
                ref_voice_path = gr.Audio(label='Reference voice', type='filepath')
        with gr.Column(scale=1, min_width=200):
            with gr.Row():
                sep_voice = gr.Audio(label='Separated voice')
            with gr.Row():
                sep_btn = gr.Button("Separate voices", size='sm')
                sep_btn.click(separate, inputs=[mixed_voice_path, ref_voice_path, drop_down], outputs=sep_voice)

demo.launch()