Spaces: Fabrice-TIERCELIN (runtime error)
Fabrice-TIERCELIN committed: Add new code
app.py CHANGED
@@ -17,7 +17,7 @@ from audioldm.audio.stft import TacotronSTFT
 from audioldm.variational_autoencoder import AutoencoderKL
 from pydub import AudioSegment
 
-# Old
+# Old import
 import numpy as np
 import torch.nn.functional as F
 from torchvision.transforms.functional import normalize
@@ -28,6 +28,144 @@ import PIL
 from PIL import Image
 from typing import Tuple
 
+max_64_bit_int = 2**63 - 1
+
+# Automatic device detection
+if torch.cuda.is_available():
+    device_type = "cuda"
+    device_selection = "cuda:0"
+else:
+    device_type = "cpu"
+    device_selection = "cpu"
+
+class Tango:
+    def __init__(self, name = "declare-lab/tango2", device = device_selection):
+
+        path = snapshot_download(repo_id = name)
+
+        vae_config = json.load(open("{}/vae_config.json".format(path)))
+        stft_config = json.load(open("{}/stft_config.json".format(path)))
+        main_config = json.load(open("{}/main_config.json".format(path)))
+
+        self.vae = AutoencoderKL(**vae_config).to(device)
+        self.stft = TacotronSTFT(**stft_config).to(device)
+        self.model = AudioDiffusion(**main_config).to(device)
+
+        vae_weights = torch.load("{}/pytorch_model_vae.bin".format(path), map_location = device)
+        stft_weights = torch.load("{}/pytorch_model_stft.bin".format(path), map_location = device)
+        main_weights = torch.load("{}/pytorch_model_main.bin".format(path), map_location = device)
+
+        self.vae.load_state_dict(vae_weights)
+        self.stft.load_state_dict(stft_weights)
+        self.model.load_state_dict(main_weights)
+
+        print("Successfully loaded checkpoint from:", name)
+
+        self.vae.eval()
+        self.stft.eval()
+        self.model.eval()
+
+        self.scheduler = DDPMScheduler.from_pretrained(main_config["scheduler_name"], subfolder = "scheduler")
+
+    def chunks(self, lst, n):
+        # Yield successive n-sized chunks from a list
+        for i in range(0, len(lst), n):
+            yield lst[i:i + n]
+
+    def generate(self, prompt, steps = 100, guidance = 3, samples = 1, disable_progress = True):
+        # Generate audio for a single prompt string
+        with torch.no_grad():
+            latents = self.model.inference([prompt], self.scheduler, steps, guidance, samples, disable_progress = disable_progress)
+            mel = self.vae.decode_first_stage(latents)
+            wave = self.vae.decode_to_waveform(mel)
+        return wave
+
+    def generate_for_batch(self, prompts, steps = 200, guidance = 3, samples = 1, batch_size = 8, disable_progress = True):
+        # Generate audio for a list of prompt strings
+        outputs = []
+        for k in tqdm(range(0, len(prompts), batch_size)):
+            batch = prompts[k: k + batch_size]
+            with torch.no_grad():
+                latents = self.model.inference(batch, self.scheduler, steps, guidance, samples, disable_progress = disable_progress)
+                mel = self.vae.decode_first_stage(latents)
+                wave = self.vae.decode_to_waveform(mel)
+                outputs += [item for item in wave]
+        if samples == 1:
+            return outputs
+        return list(self.chunks(outputs, samples))
+
+# Initialize TANGO
+
+tango = Tango(device = "cpu")
+tango.vae.to(device_type)
+tango.stft.to(device_type)
+tango.model.to(device_type)
+
+def update_seed(is_randomize_seed, seed):
+    if is_randomize_seed:
+        return random.randint(0, max_64_bit_int)
+    return seed
+
+def check(
+    prompt,
+    output_number,
+    steps,
+    guidance,
+    is_randomize_seed,
+    seed
+):
+    if prompt is None or prompt == "":
+        raise gr.Error("Please provide a prompt input.")
+    if output_number not in [1, 2, 3]:
+        raise gr.Error("Please ask for 1, 2 or 3 output files.")
+
+def update_output(output_format, output_number):
+    return [
+        gr.update(format = output_format),
+        gr.update(format = output_format, visible = (2 <= output_number)),
+        gr.update(format = output_format, visible = (output_number == 3)),
+        gr.update(visible = False)
+    ]
+
+def text2audio(
+    prompt,
+    output_number,
+    steps,
+    guidance,
+    is_randomize_seed,
+    seed
+):
+    start = time.time()
+
+    if seed is None:
+        seed = random.randint(0, max_64_bit_int)
+
+    random.seed(seed)
+    torch.manual_seed(seed)
+
+    output_wave = tango.generate(prompt, steps, guidance, output_number)
+
+    output_wave_1 = gr.make_waveform((16000, output_wave[0]))
+    output_wave_2 = gr.make_waveform((16000, output_wave[1])) if (2 <= output_number) else None
+    output_wave_3 = gr.make_waveform((16000, output_wave[2])) if (output_number == 3) else None
+
+    end = time.time()
+    secondes = int(end - start)
+    minutes = secondes // 60
+    secondes = secondes - (minutes * 60)
+    hours = minutes // 60
+    minutes = minutes - (hours * 60)
+    return [
+        output_wave_1,
+        output_wave_2,
+        output_wave_3,
+        gr.update(visible = True, value = "Start again to get a different result. The output has been generated in " + ((str(hours) + " h, ") if hours != 0 else "") + ((str(minutes) + " min, ") if hours != 0 or minutes != 0 else "") + str(secondes) + " sec.")
+    ]
+
+if is_space_imported:
+    text2audio = spaces.GPU(text2audio, duration = 420)
+
+# Old code
 net=BriaRMBG()
 # model_path = "./model1.pth"
 #model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
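
For reference, a minimal usage sketch of the Tango wrapper added above, assuming the declare-lab/tango2 checkpoint download succeeds; the soundfile dependency and the output file name are illustrative, not part of the commit. generate() returns a batch of waveforms even for a single prompt (text2audio above indexes output_wave[0] the same way), and TANGO decodes audio at 16 kHz:

import soundfile as sf

# Hypothetical usage, not from this commit: synthesize one prompt and
# write the first returned 16 kHz waveform to disk
wave = tango.generate("An audience cheering and clapping", steps = 100, guidance = 3)
sf.write("output.wav", wave[0], samplerate = 16000)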
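
The elapsed-time split at the end of text2audio can also be written with divmod; this behaviour-preserving sketch (the helper name is invented) yields the same hours, minutes and seconds:

def format_elapsed(elapsed):
    # Same arithmetic as text2audio: floor to whole seconds, then split
    minutes, secondes = divmod(int(elapsed), 60)
    hours, minutes = divmod(minutes, 60)
    return hours, minutes, secondes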
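
The commit only defines the callbacks; the Gradio wiring sits outside this hunk. One plausible sketch, assuming a standard gr.Blocks layout (all component names and default values below are invented; gr.make_waveform renders a video file, hence gr.Video outputs). update_output would attach to output_number.change() in the same fashion and is omitted here:

import gradio as gr

with gr.Blocks() as interface:
    prompt = gr.Textbox(label = "Prompt")
    output_number = gr.Slider(label = "Number of outputs", minimum = 1, maximum = 3, step = 1, value = 1)
    steps = gr.Slider(label = "Steps", minimum = 10, maximum = 200, step = 1, value = 100)
    guidance = gr.Slider(label = "Guidance", minimum = 1, maximum = 10, step = 0.1, value = 3)
    is_randomize_seed = gr.Checkbox(label = "Randomize seed", value = True)
    seed = gr.Number(label = "Seed", value = 123)
    submit = gr.Button("Generate")

    # gr.make_waveform returns each waveform as a video file
    output_1 = gr.Video(label = "Output 1")
    output_2 = gr.Video(label = "Output 2", visible = False)
    output_3 = gr.Video(label = "Output 3", visible = False)
    information = gr.HTML(visible = False)

    # Validate, reseed, then generate; .success() runs only if the
    # previous step did not raise gr.Error
    submit.click(
        fn = check,
        inputs = [prompt, output_number, steps, guidance, is_randomize_seed, seed]
    ).success(
        fn = update_seed,
        inputs = [is_randomize_seed, seed],
        outputs = [seed]
    ).success(
        fn = text2audio,
        inputs = [prompt, output_number, steps, guidance, is_randomize_seed, seed],
        outputs = [output_1, output_2, output_3, information]
    )

interface.launch()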
|