Fabrice-TIERCELIN committed
Commit fc362b8 · verified · 1 Parent(s): 0ff2686

Add new code

Files changed (1)
  1. app.py +139 -1
app.py CHANGED
@@ -17,7 +17,7 @@ from audioldm.audio.stft import TacotronSTFT
 from audioldm.variational_autoencoder import AutoencoderKL
 from pydub import AudioSegment
 
-# Old code
+# Old import
 import numpy as np
 import torch.nn.functional as F
 from torchvision.transforms.functional import normalize
@@ -28,6 +28,144 @@ import PIL
 from PIL import Image
 from typing import Tuple
 
+max_64_bit_int = 2**63 - 1
+
+# Automatic device detection
+if torch.cuda.is_available():
+    device_type = "cuda"
+    device_selection = "cuda:0"
+else:
+    device_type = "cpu"
+    device_selection = "cpu"
+
+class Tango:
+    def __init__(self, name = "declare-lab/tango2", device = device_selection):
+
+        path = snapshot_download(repo_id = name)
+
+        vae_config = json.load(open("{}/vae_config.json".format(path)))
+        stft_config = json.load(open("{}/stft_config.json".format(path)))
+        main_config = json.load(open("{}/main_config.json".format(path)))
+
+        self.vae = AutoencoderKL(**vae_config).to(device)
+        self.stft = TacotronSTFT(**stft_config).to(device)
+        self.model = AudioDiffusion(**main_config).to(device)
+
+        vae_weights = torch.load("{}/pytorch_model_vae.bin".format(path), map_location = device)
+        stft_weights = torch.load("{}/pytorch_model_stft.bin".format(path), map_location = device)
+        main_weights = torch.load("{}/pytorch_model_main.bin".format(path), map_location = device)
+
+        self.vae.load_state_dict(vae_weights)
+        self.stft.load_state_dict(stft_weights)
+        self.model.load_state_dict(main_weights)
+
+        print("Successfully loaded checkpoint from:", name)
+
+        self.vae.eval()
+        self.stft.eval()
+        self.model.eval()
+
+        self.scheduler = DDPMScheduler.from_pretrained(main_config["scheduler_name"], subfolder = "scheduler")
+
+    def chunks(self, lst, n):
+        # Yield successive n-sized chunks from a list
+        for i in range(0, len(lst), n):
+            yield lst[i:i + n]
+
+    def generate(self, prompt, steps = 100, guidance = 3, samples = 1, disable_progress = True):
+        # Generate audio for a single prompt string
+        with torch.no_grad():
+            latents = self.model.inference([prompt], self.scheduler, steps, guidance, samples, disable_progress = disable_progress)
+            mel = self.vae.decode_first_stage(latents)
+            wave = self.vae.decode_to_waveform(mel)
+        return wave
+
+    def generate_for_batch(self, prompts, steps = 200, guidance = 3, samples = 1, batch_size = 8, disable_progress = True):
+        # Generate audio for a list of prompt strings
+        outputs = []
+        for k in tqdm(range(0, len(prompts), batch_size)):
+            batch = prompts[k: k + batch_size]
+            with torch.no_grad():
+                latents = self.model.inference(batch, self.scheduler, steps, guidance, samples, disable_progress = disable_progress)
+                mel = self.vae.decode_first_stage(latents)
+                wave = self.vae.decode_to_waveform(mel)
+                outputs += [item for item in wave]
+        if samples == 1:
+            return outputs
+        return list(self.chunks(outputs, samples))
+
+# Initialize TANGO
+
+tango = Tango(device = "cpu")
+tango.vae.to(device_type)
+tango.stft.to(device_type)
+tango.model.to(device_type)
+
+def update_seed(is_randomize_seed, seed):
+    if is_randomize_seed:
+        return random.randint(0, max_64_bit_int)
+    return seed
+
+def check(
+    prompt,
+    output_number,
+    steps,
+    guidance,
+    is_randomize_seed,
+    seed
+):
+    if prompt is None or prompt == "":
+        raise gr.Error("Please provide a prompt input.")
+    if output_number not in [1, 2, 3]:
+        raise gr.Error("Please ask for 1, 2 or 3 output files.")
+
+def update_output(output_format, output_number):
+    return [
+        gr.update(format = output_format),
+        gr.update(format = output_format, visible = (2 <= output_number)),
+        gr.update(format = output_format, visible = (output_number == 3)),
+        gr.update(visible = False)
+    ]
+
+def text2audio(
+    prompt,
+    output_number,
+    steps,
+    guidance,
+    is_randomize_seed,
+    seed
+):
+    start = time.time()
+
+    if seed is None:
+        seed = random.randint(0, max_64_bit_int)
+
+    random.seed(seed)
+    torch.manual_seed(seed)
+
+    output_wave = tango.generate(prompt, steps, guidance, output_number)
+
+    output_wave_1 = gr.make_waveform((16000, output_wave[0]))
+    output_wave_2 = gr.make_waveform((16000, output_wave[1])) if (2 <= output_number) else None
+    output_wave_3 = gr.make_waveform((16000, output_wave[2])) if (output_number == 3) else None
+
+    end = time.time()
+    secondes = int(end - start)
+    minutes = secondes // 60
+    secondes = secondes - (minutes * 60)
+    hours = minutes // 60
+    minutes = minutes - (hours * 60)
+    return [
+        output_wave_1,
+        output_wave_2,
+        output_wave_3,
+        gr.update(visible = True, value = "Start again to get a different result. The output has been generated in " + ((str(hours) + " h, ") if hours != 0 else "") + ((str(minutes) + " min, ") if hours != 0 or minutes != 0 else "") + str(secondes) + " sec.")
+    ]
+
+if is_space_imported:
+    text2audio = spaces.GPU(text2audio, duration = 420)
+
+# Old code
 net=BriaRMBG()
 # model_path = "./model1.pth"
 #model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
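
For reviewers who want to sanity-check the new code path outside Gradio, here is a minimal smoke-test sketch. It assumes the app.py context above has already run (the imports and the tango instance); the prompt, the output file name, and the soundfile dependency are illustrative additions, not part of the commit.

# Smoke-test sketch, not part of the commit.
# Assumes the app.py module above has been executed (imports, tango instance).
import soundfile as sf  # hypothetical extra dependency, used here only to save the result

# Ask for one sample; Tango.generate() returns one waveform per requested sample.
waves = tango.generate("A dog barking in the distance", steps = 100, guidance = 3, samples = 1)

# decode_to_waveform() yields 16 kHz audio, matching the (16000, ...) tuples
# passed to gr.make_waveform() in text2audio().
sf.write("smoke_test.wav", waves[0], samplerate = 16000)

In the Space itself the same generation call is reached through text2audio(), which the commit wraps in spaces.GPU(..., duration = 420) when the spaces package is importable, so a GPU is requested per call rather than held for the lifetime of the process.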