File size: 3,794 Bytes
7596274
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import gradio as gr
import torch
import torchaudio
from torch.cuda.amp import autocast

from network.models import FilterBandTFGridnet, ResemblyzerVoiceEncoder

# Inference is pinned to the CPU. To use a GPU when one is present, set
# device to: 'cuda' if torch.cuda.is_available() else 'cpu'
device = 'cpu'
# The separator is conditioned on two speaker embeddings concatenated along
# dim=1 (reference voice + mixture, see `seprate`), hence 2 * 256; 256 is
# presumably the voice encoder's embedding size — confirm against the encoder.
model = FilterBandTFGridnet(conditional_dim=2 * 256, n_layers=5)
emb = ResemblyzerVoiceEncoder(device=device)
# Most recent waveform produced by `mix`; read by `seprate` when the
# 'From mixing tool' source is selected. None until a mix has been made.
mixed_voice_tool = None

def load_voice(voice_path, target_rate=16000):
    """Load an audio file as a float32 waveform at ``target_rate`` Hz.

    Args:
        voice_path: Path to an audio file readable by ``torchaudio.load``.
        target_rate: Sample rate (Hz) of the returned waveform. Defaults to
            16 kHz, the rate the rest of this app assumes (mixing output,
            the 4-second reference window, and the 16k->8k model resample).

    Returns:
        ``(waveform, rate)``: the loaded waveform cast to float32 and resampled
        to ``target_rate`` if needed, and ``rate == target_rate``.
    """
    voice, rate = torchaudio.load(voice_path)
    if rate != target_rate:
        # Resample to exactly the rate that is reported back, so callers can
        # rely on the returned `rate` describing the waveform.
        voice = torchaudio.functional.resample(voice, rate, target_rate)
        rate = target_rate
    return voice.float(), rate

def mix(voice1_path, voice2_path, snr=0):
    """Mix two speaker recordings at a given SNR and cache the mixture.

    ``voice2`` is added to ``voice1`` as noise with
    ``torchaudio.functional.add_noise`` at ``snr`` dB. The mixture is stored
    in the module-level ``mixed_voice_tool`` so ``seprate`` can reuse it.

    Args:
        voice1_path: Path to the first speaker's recording.
        voice2_path: Path to the second speaker's recording.
        snr: Signal-to-noise ratio in dB of voice1 relative to voice2.

    Returns:
        A ``gr.Audio`` holding entry 0 (along dim 0) of the mixture at 16 kHz.

    Raises:
        gr.Error: If either recording is missing.
    """
    global mixed_voice_tool
    sample_rate = 16000  # rate of every waveform returned by load_voice
    if voice1_path is None or voice2_path is None:
        raise gr.Error("Please provide both speaker recordings.")
    voice1, _ = load_voice(voice1_path)
    voice2, _ = load_voice(voice2_path)
    # add_noise requires waveform and noise of equal length; two independent
    # recordings rarely match, so mix only their overlapping span.
    length = min(voice1.shape[-1], voice2.shape[-1])
    voice1 = voice1[..., :length]
    voice2 = voice2[..., :length]
    mixture = torchaudio.functional.add_noise(
        voice1, voice2, torch.tensor([float(snr)])
    ).float()
    mixed_voice_tool = mixture
    return gr.Audio(tuple((sample_rate, mixture[0].numpy())), type='numpy')

# def seprate_from_file(mixed_voice, ref_voice):

def seprate(mixed_voice_path, clean_voice_path, drop_down):
    """Extract the reference speaker's voice from a mixed recording.

    Args:
        mixed_voice_path: Path to an uploaded mixture; ignored when
            ``drop_down == 'From mixing tool'``.
        clean_voice_path: Path to a reference recording of the target speaker.
        drop_down: Mixture source chosen in the UI. 'From mixing tool' uses the
            waveform cached by ``mix``; any other value loads
            ``mixed_voice_path``.

    Returns:
        A ``gr.Audio`` with entry 0 (along dim 0) of the separated signal at
        16 kHz, or None if no model is configured.

    Raises:
        gr.Error: If a required input is missing or the reference is empty.
    """
    sample_rate = 16000            # rate of every waveform from load_voice
    model_rate = 8000              # rate the separator is run at
    ref_samples = sample_rate * 4  # reference is tiled/cropped to 4 seconds

    if model is None:
        return None

    if drop_down == 'From mixing tool':
        if mixed_voice_tool is None:
            raise gr.Error(
                "No mixed voice yet: click 'Mix voices' first or choose 'Upload'."
            )
        mixed_voice = mixed_voice_tool
    else:
        if mixed_voice_path is None:
            raise gr.Error("Please upload a mixed voice.")
        mixed_voice, _ = load_voice(mixed_voice_path)

    if clean_voice_path is None:
        raise gr.Error("Please upload a reference voice.")
    clean_voice, _ = load_voice(clean_voice_path)
    ref_len = clean_voice.shape[-1]
    if ref_len == 0:
        raise gr.Error("The reference voice is empty.")
    if ref_len < ref_samples:
        # Tile a short reference until it covers ref_samples, then crop.
        repeats = ref_samples // ref_len + 1
        clean_voice = torch.cat([clean_voice] * repeats, dim=-1)[:, :ref_samples]

    model.to(device)
    model.eval()
    # Condition on both the reference and the mixture embeddings.
    embedding = torch.cat([emb(clean_voice), emb(mixed_voice)], dim=1).to(device)
    model_input = torchaudio.functional.resample(
        mixed_voice, sample_rate, model_rate
    ).to(device)
    # Mixed precision only makes sense on CUDA; on CPU it is explicitly off.
    with torch.no_grad(), autocast(enabled=device == 'cuda'):
        separated = model(model_input, embedding)
    # Back to float32 on the CPU before resampling and handing to numpy.
    separated = torchaudio.functional.resample(
        separated.float().cpu(), model_rate, sample_rate
    )
    audio = separated.numpy().astype('float32')
    return gr.Audio(tuple((sample_rate, audio[0])), type='numpy')

def load_checkpoint(filepath):
    """Load the state dict stored at ``filepath`` into the global ``model``.

    The file is read with ``weights_only=True`` and its tensors are mapped
    onto the module-level ``device``. Returns None.
    """
    state_dict = torch.load(filepath, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)

# Weights are loaded at import time so `demo` is ready both when this script
# is run directly and when the module is imported instead.
load_checkpoint('checkpoints/concat_emb.pth')

with gr.Blocks() as demo:
    # Mixing tool: combine two speaker recordings at a chosen SNR.
    with gr.Row():
        snr = gr.Slider(label='SNR', minimum=-10, maximum=10, step=1, value=0)
    with gr.Row():
        with gr.Column(scale=1, min_width=200):
            voice1 = gr.Audio(label='speaker 1', type='filepath')
        with gr.Column(scale=1, min_width=200):
            voice2 = gr.Audio(label='speaker 2', type='filepath')
        with gr.Column(scale=1, min_width=200):
            with gr.Row():
                mixed_voice = gr.Audio(label='Mixed voice')
            with gr.Row():
                mix_btn = gr.Button("Mix voices", size='sm')
                mix_btn.click(mix, inputs=[voice1, voice2, snr], outputs=mixed_voice)

    # Separation: extract the reference speaker from a mixture.
    with gr.Row():
        # Static section heading. gr.Label is a classification-output widget,
        # so Markdown is used for plain text.
        choose_mix_source = gr.Markdown('### Extract target speaker voice from mixed voice')
    with gr.Row():
        drop_down = gr.Dropdown(['From mixing tool', 'Upload'], label='Choose mixed voice source')
    with gr.Row():
        with gr.Column(scale=1, min_width=200):
            with gr.Row():
                mixed_voice_path = gr.Audio(label='Mixed voice', type='filepath')
        with gr.Column(scale=1, min_width=200):
            with gr.Row():
                ref_voice_path = gr.Audio(label='reference voice', type='filepath')
        with gr.Column(scale=1, min_width=200):
            with gr.Row():
                sep_voice = gr.Audio(label="Separate Voice")
            with gr.Row():
                sep_btn = gr.Button("Separate voices", size='sm')
                sep_btn.click(
                    seprate,
                    inputs=[mixed_voice_path, ref_voice_path, drop_down],
                    outputs=sep_voice,
                )

# Only start the server when run as a script, not on import.
if __name__ == '__main__':
    demo.launch()