import os
import spaces  # Hugging Face Spaces helper, required for the @spaces.GPU decorator below
import torchaudio
import torch
import numpy as np
import gradio as gr
import yaml
import librosa
from tqdm import tqdm

import look2hear.models
from ml_collections import ConfigDict

def load_audio(file_path):
    # librosa returns float32 audio at 44.1 kHz: shape (channels, samples) for
    # stereo files, (samples,) for mono (mono is reshaped in enhance() below)
    audio, samplerate = librosa.load(file_path, mono=False, sr=44100)
    print(f'INPUT audio.shape = {audio.shape} | samplerate = {samplerate}')
    #audio = dBgain(audio, -6)
    return torch.from_numpy(audio), samplerate


def get_config(config_path):
    with open(config_path) as f:
        #config = OmegaConf.load(config_path)
        config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
        return config




def _getWindowingArray(window_size, fade_size):
    # IMPORTANT NOTE:
    # there is no real cross-fade here: the "fade-in" is flat (all ones) and the
    # "fade-out" zeroes the tail, simply discarding the unreliable end of each chunk
    fadein = torch.ones(fade_size)
    fadeout = torch.zeros(fade_size)
    window = torch.ones(window_size)
    window[-fade_size:] *= fadeout
    window[:fade_size] *= fadein
    return window



description = '''
texts
'''


apollo_config = get_config('configs/apollo.yaml')
apollo_model = look2hear.models.BaseModel.from_pretrain('weights/apollo.bin', **apollo_config['model']).cuda()
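# NOTE: the Apollo model is loaded once at startup and kept on the GPU; every
# request handled by enhance() below reuses this single instance.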

models = [
    ('MP3 restore', apollo_model)
]
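
# Each entry pairs a human-readable label with a loaded model; the Gradio
# Dropdown below uses them as (label, value) choices, so enhance() receives the
# model object itself rather than the label.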

@spaces.GPU
def enhance(model, audio):
    test_data, samplerate = load_audio(audio)
    C = 10 * samplerate    # chunk size: 10 seconds, in samples
    N = 2                  # overlap factor: consecutive chunks overlap by half a chunk
    step = C // N          # hop between chunk starts
    fade_size = 3 * 44100  # 3 seconds at 44.1 kHz
    print(f"N = {N} | C = {C} | step = {step} | fade_size = {fade_size}")
    
    border = C - step
    
    # handle mono inputs correctly
    if len(test_data.shape) == 1:
        test_data = test_data.unsqueeze(0) 

    # Pad the input if necessary
    if test_data.shape[1] > 2 * border and (border > 0):
        test_data = torch.nn.functional.pad(test_data, (border, border), mode='reflect')
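    # The reflect padding above mirrors `border` samples at each end so the very
    # start and end of the signal are still processed with full chunk context;
    # it is trimmed off again after inference (see "Remove padding" below).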

    windowingArray = _getWindowingArray(C, fade_size)

    result = torch.zeros((1,) + tuple(test_data.shape), dtype=torch.float32)
    counter = torch.zeros((1,) + tuple(test_data.shape), dtype=torch.float32)
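    # Overlap-add bookkeeping: each chunk's windowed output is accumulated into
    # `result` and the window itself into `counter`, so dividing the two at the
    # end renormalizes samples covered by more than one chunk.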

    i = 0
    progress_bar = tqdm(total=test_data.shape[1], desc="Processing audio chunks", leave=False)

    while i < test_data.shape[1]:
        part = test_data[:, i:i + C]
        length = part.shape[-1]
        if length < C:
            if length > C // 2 + 1:
                part = torch.nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
            else:
                part = torch.nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)
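        # (reflect padding requires the pad amount to be smaller than the input
        #  length, hence the constant zero-padding fallback for very short tails)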


        chunk = part.unsqueeze(0).cuda()
        with torch.no_grad():
            out = model(chunk).squeeze(0).squeeze(0).cpu()

        window = windowingArray.clone()  # clone so per-chunk tweaks don't mutate the shared array
        if i == 0:  # first audio chunk, no fade-in
            window[:fade_size] = 1
        elif i + C >= test_data.shape[1]:  # last audio chunk, no fade-out
            window[-fade_size:] = 1

        result[..., i:i+length] += out[..., :length] * window[..., :length]
        counter[..., i:i+length] += window[..., :length]

        i += step
        progress_bar.update(step)

    progress_bar.close()

    final_output = result / counter
    final_output = final_output.squeeze(0).numpy()
    np.nan_to_num(final_output, copy=False, nan=0.0)

    # Remove padding if added earlier
    if test_data.shape[1] > 2 * border and (border > 0):
        final_output = final_output[..., border:-border]
    
    return samplerate, final_output.T
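

# Minimal offline sketch (not wired into the Gradio UI): run the same enhancement
# on a local file and save the result with torchaudio. 'input.mp3' and
# 'restored.wav' are hypothetical placeholder paths; like enhance() itself, this
# assumes a CUDA device is available.
def _example_offline_usage(in_path='input.mp3', out_path='restored.wav'):
    sr, restored = enhance(apollo_model, in_path)  # restored has shape (samples, channels)
    torchaudio.save(out_path, torch.tensor(restored.T), sr)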


if __name__ == "__main__":
    demo = gr.Interface(
        fn=enhance,
        description=description,
        inputs=[
            gr.Dropdown(label="Model", choices=models, value=models[0][1]),
            gr.Audio(label="Input Audio:", interactive=True, type='filepath', max_length=300, waveform_options={'waveform_progress_color': '#3C82F6'}),
        ],
        outputs=[
            gr.Audio(
                label="Output Audio",
                autoplay=False,
                streaming=False,
                type="numpy",
            ),
        ],
        allow_flagging='never',
        cache_examples=False,
        title='Enhancer',
    )
    demo.queue(max_size=20, default_concurrency_limit=4)
    demo.launch(share=False, server_name="0.0.0.0")