Atin Sakkeer Hussain commited on
Commit
c399026
·
1 Parent(s): b5e6f78

Add app.py

Browse files
Files changed (3) hide show
  1. app.py +382 -0
  2. bot.png +0 -0
  3. user.png +0 -0
app.py ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.cuda
2
+
3
+ import gradio as gr
4
+ import mdtex2html
5
+ import tempfile
6
+ from PIL import Image
7
+ import scipy
8
+ import argparse
9
+
10
+ from llama.m2ugen import M2UGen
11
+ import llama
12
+ import numpy as np
13
+ import os
14
+ import torch
15
+ import torchaudio
16
+ import torchvision.transforms as transforms
17
+ import av
18
+ import subprocess
19
+ import librosa
20
+
21
+ parser = argparse.ArgumentParser()
22
+ parser.add_argument(
23
+ "--model", default="./ckpts/checkpoint.pth", type=str,
24
+ help="Name of or path to M2UGen pretrained checkpoint",
25
+ )
26
+ parser.add_argument(
27
+ "--llama_type", default="7B", type=str,
28
+ help="Type of llama original weight",
29
+ )
30
+ parser.add_argument(
31
+ "--llama_dir", default="/path/to/llama", type=str,
32
+ help="Path to LLaMA pretrained checkpoint",
33
+ )
34
+ parser.add_argument(
35
+ "--mert_path", default="m-a-p/MERT-v1-330M", type=str,
36
+ help="Path to MERT pretrained checkpoint",
37
+ )
38
+ parser.add_argument(
39
+ "--vit_path", default="m-a-p/MERT-v1-330M", type=str,
40
+ help="Path to ViT pretrained checkpoint",
41
+ )
42
+ parser.add_argument(
43
+ "--vivit_path", default="m-a-p/MERT-v1-330M", type=str,
44
+ help="Path to ViViT pretrained checkpoint",
45
+ )
46
+ parser.add_argument(
47
+ "--knn_dir", default="./ckpts", type=str,
48
+ help="Path to directory with KNN Index",
49
+ )
50
+ parser.add_argument(
51
+ '--music_decoder', default="musicgen", type=str,
52
+ help='Decoder to use musicgen/audioldm2')
53
+
54
+ parser.add_argument(
55
+ '--music_decoder_path', default="facebook/musicgen-medium", type=str,
56
+ help='Path to decoder to use musicgen/audioldm2')
57
+
58
+ args = parser.parse_args()
59
+
60
+ generated_audio_files = []
61
+
62
+ llama_type = args.llama_type
63
+ llama_ckpt_dir = os.path.join(args.llama_dir, llama_type)
64
+ llama_tokenzier_path = args.llama_dir
65
+ model = M2UGen(llama_ckpt_dir, llama_tokenzier_path, args, knn=False, stage=None, load_llama=False)
66
+
67
+ print("Loading Model Checkpoint")
68
+ checkpoint = torch.load(args.model, map_location='cpu')
69
+
70
+ new_ckpt = {}
71
+ for key, value in checkpoint['model'].items():
72
+ if "generation_model" in key:
73
+ continue
74
+ key = key.replace("module.", "")
75
+ new_ckpt[key] = value
76
+
77
+ load_result = model.load_state_dict(new_ckpt, strict=False)
78
+ assert len(load_result.unexpected_keys) == 0, f"Unexpected keys: {load_result.unexpected_keys}"
79
+ model.eval()
80
+ model.to("cuda")
81
+ #model.generation_model.to("cuda")
82
+ #model.mert_model.to("cuda")
83
+ #model.vit_model.to("cuda")
84
+ #model.vivit_model.to("cuda")
85
+
86
+ transform = transforms.Compose(
87
+ [transforms.ToTensor(), transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.size(0) == 1 else x)])
88
+
89
+
90
+ def postprocess(self, y):
91
+ if y is None:
92
+ return []
93
+ for i, (message, response) in enumerate(y):
94
+ y[i] = (
95
+ None if message is None else mdtex2html.convert((message)),
96
+ None if response is None else mdtex2html.convert(response),
97
+ )
98
+ return y
99
+
100
+
101
+ gr.Chatbot.postprocess = postprocess
102
+
103
+
104
+ def parse_text(text, image_path, video_path, audio_path):
105
+ """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
106
+ outputs = text
107
+ lines = text.split("\n")
108
+ lines = [line for line in lines if line != ""]
109
+ count = 0
110
+ for i, line in enumerate(lines):
111
+ if "```" in line:
112
+ count += 1
113
+ items = line.split('`')
114
+ if count % 2 == 1:
115
+ lines[i] = f'<pre><code class="language-{items[-1]}">'
116
+ else:
117
+ lines[i] = f'<br></code></pre>'
118
+ else:
119
+ if i > 0:
120
+ if count % 2 == 1:
121
+ line = line.replace("`", "\`")
122
+ line = line.replace("<", "&lt;")
123
+ line = line.replace(">", "&gt;")
124
+ line = line.replace(" ", "&nbsp;")
125
+ line = line.replace("*", "&ast;")
126
+ line = line.replace("_", "&lowbar;")
127
+ line = line.replace("-", "&#45;")
128
+ line = line.replace(".", "&#46;")
129
+ line = line.replace("!", "&#33;")
130
+ line = line.replace("(", "&#40;")
131
+ line = line.replace(")", "&#41;")
132
+ line = line.replace("$", "&#36;")
133
+ lines[i] = "<br>" + line
134
+ text = "".join(lines) + "<br>"
135
+ if image_path is not None:
136
+ text += f'<img src="./file={image_path}" style="display: inline-block;"><br>'
137
+ outputs = f'<Image>{image_path}</Image> ' + outputs
138
+ if video_path is not None:
139
+ text += f' <video controls playsinline height="320" width="240" style="display: inline-block;" src="./file={video_path}"></video6><br>'
140
+ outputs = f'<Video>{video_path}</Video> ' + outputs
141
+ if audio_path is not None:
142
+ text += f'<audio controls playsinline><source src="./file={audio_path}" type="audio/wav"></audio><br>'
143
+ outputs = f'<Audio>{audio_path}</Audio> ' + outputs
144
+ # text = text[::-1].replace(">rb<", "", 1)[::-1]
145
+ text = text[:-len("<br>")].rstrip() if text.endswith("<br>") else text
146
+ return text, outputs
147
+
148
+
149
+ def save_audio_to_local(audio, sec):
150
+ global generated_audio_files
151
+ if not os.path.exists('temp'):
152
+ os.mkdir('temp')
153
+ filename = os.path.join('temp', next(tempfile._get_candidate_names()) + '.wav')
154
+ if args.music_decoder == "audioldm2":
155
+ scipy.io.wavfile.write(filename, rate=16000, data=audio[0])
156
+ else:
157
+ scipy.io.wavfile.write(filename, rate=model.generation_model.config.audio_encoder.sampling_rate, data=audio)
158
+ generated_audio_files.append(filename)
159
+ return filename
160
+
161
+
162
+ def parse_reponse(model_outputs, audio_length_in_s):
163
+ response = ''
164
+ text_outputs = []
165
+ for output_i, p in enumerate(model_outputs):
166
+ if isinstance(p, str):
167
+ response += p.replace(' '.join([f'[AUD{i}]' for i in range(8)]), '')
168
+ response += '<br>'
169
+ text_outputs.append(p.replace(' '.join([f'[AUD{i}]' for i in range(8)]), ''))
170
+ elif 'aud' in p.keys():
171
+ _temp_output = ''
172
+ for idx, m in enumerate(p['aud']):
173
+ if isinstance(m, str):
174
+ response += m.replace(' '.join([f'[AUD{i}]' for i in range(8)]), '')
175
+ response += '<br>'
176
+ _temp_output += m.replace(' '.join([f'[AUD{i}]' for i in range(8)]), '')
177
+ else:
178
+ filename = save_audio_to_local(m, audio_length_in_s)
179
+ print(filename)
180
+ _temp_output = f'<Audio>{filename}</Audio> ' + _temp_output
181
+ response += f'<audio controls playsinline><source src="./file={filename}" type="audio/wav"></audio>'
182
+ text_outputs.append(_temp_output)
183
+ else:
184
+ pass
185
+ response = response[:-len("<br>")].rstrip() if response.endswith("<br>") else response
186
+ return response, text_outputs
187
+
188
+
189
+ def reset_user_input():
190
+ return gr.update(value='')
191
+
192
+
193
+ def reset_dialog():
194
+ return [], []
195
+
196
+
197
+ def reset_state():
198
+ global generated_audio_files
199
+ generated_audio_files = []
200
+ return None, None, None, None, [], [], []
201
+
202
+
203
+ def upload_image(conversation, chat_history, image_input):
204
+ input_image = Image.open(image_input.name).resize(
205
+ (224, 224)).convert('RGB')
206
+ input_image.save(image_input.name) # Overwrite with smaller image.
207
+ conversation += [(f'<img src="./file={image_input.name}" style="display: inline-block;">', "")]
208
+ return conversation, chat_history + [input_image, ""]
209
+
210
+
211
+ def read_video_pyav(container, indices):
212
+ frames = []
213
+ container.seek(0)
214
+ for i, frame in enumerate(container.decode(video=0)):
215
+ frames.append(frame)
216
+ chosen_frames = []
217
+ for i in indices:
218
+ chosen_frames.append(frames[i])
219
+ return np.stack([x.to_ndarray(format="rgb24") for x in chosen_frames])
220
+
221
+
222
+ def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
223
+ converted_len = int(clip_len * frame_sample_rate)
224
+ if converted_len > seg_len:
225
+ converted_len = 0
226
+ end_idx = np.random.randint(converted_len, seg_len)
227
+ start_idx = end_idx - converted_len
228
+ indices = np.linspace(start_idx, end_idx, num=clip_len)
229
+ indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
230
+ return indices
231
+
232
+
233
+ def get_video_length(filename):
234
+ print("Getting Video Length")
235
+ result = subprocess.run(["ffprobe", "-v", "error", "-show_entries",
236
+ "format=duration", "-of",
237
+ "default=noprint_wrappers=1:nokey=1", filename],
238
+ stdout=subprocess.PIPE,
239
+ stderr=subprocess.STDOUT)
240
+ return int(round(float(result.stdout)))
241
+
242
+
243
+ def get_audio_length(filename):
244
+ return int(round(librosa.get_duration(path=filename)))
245
+
246
+
247
+ def predict(
248
+ prompt_input,
249
+ image_path,
250
+ audio_path,
251
+ video_path,
252
+ chatbot,
253
+ top_p,
254
+ temperature,
255
+ history,
256
+ modality_cache,
257
+ audio_length_in_s):
258
+ global generated_audio_files
259
+ prompts = [llama.format_prompt(prompt_input)]
260
+ prompts = [model.tokenizer(x).input_ids for x in prompts]
261
+ print(image_path, audio_path, video_path)
262
+ image, audio, video = None, None, None
263
+ if image_path is not None:
264
+ image = transform(Image.open(image_path))
265
+ if audio_path is not None:
266
+ sample_rate = 24000
267
+ waveform, sr = torchaudio.load(audio_path)
268
+ if sample_rate != sr:
269
+ waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=sample_rate)
270
+ audio = torch.mean(waveform, 0)
271
+ if video_path is not None:
272
+ print("Opening Video")
273
+ container = av.open(video_path)
274
+ indices = sample_frame_indices(clip_len=32, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
275
+ video = read_video_pyav(container=container, indices=indices)
276
+
277
+ if len(generated_audio_files) != 0:
278
+ audio_length_in_s = get_audio_length(generated_audio_files[-1])
279
+ sample_rate = 24000
280
+ waveform, sr = torchaudio.load(generated_audio_files[-1])
281
+ if sample_rate != sr:
282
+ waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=sample_rate)
283
+ audio = torch.mean(waveform, 0)
284
+ audio_length_in_s = int(len(audio)//sample_rate)
285
+ print(f"Audio Length: {audio_length_in_s}")
286
+ if video_path is not None:
287
+ audio_length_in_s = get_video_length(video_path)
288
+ print(f"Video Length: {audio_length_in_s}")
289
+ if audio_path is not None:
290
+ audio_length_in_s = get_audio_length(audio_path)
291
+ generated_audio_files.append(audio_path)
292
+ print(f"Audio Length: {audio_length_in_s}")
293
+
294
+ print(image, video, audio)
295
+ response = model.generate(prompts, audio, image, video, 200, temperature, top_p,
296
+ audio_length_in_s=audio_length_in_s)
297
+ print(response)
298
+ response_chat, response_outputs = parse_reponse(response, audio_length_in_s)
299
+ print('text_outputs: ', response_outputs)
300
+ user_chat, user_outputs = parse_text(prompt_input, image_path, video_path, audio_path)
301
+ chatbot.append((user_chat, response_chat))
302
+ history.append((user_outputs, ''.join(response_outputs).replace('\n###', '')))
303
+ return chatbot, history, modality_cache, None, None, None,
304
+
305
+
306
+ with gr.Blocks() as demo:
307
+ gr.HTML("""
308
+ <h1 align="center" style=" display: flex; flex-direction: row; justify-content: center; font-size: 25pt; "><img src='./file=bot.png' width="50" height="50" style="margin-right: 10px;">M<sup style="line-height: 200%; font-size: 60%">2</sup>UGen</h1>
309
+ <h3>This is the demo page of M<sup>2</sup>UGen, a Multimodal LLM capable of Music Understanding and Generation!</h3>
310
+ <div style="display: flex;"><a href='https://arxiv.org/pdf/2311.11255.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></div>
311
+ """)
312
+
313
+ with gr.Row():
314
+ with gr.Column(scale=0.7, min_width=500):
315
+ with gr.Row():
316
+ chatbot = gr.Chatbot(label='M2UGen Chatbot', avatar_images=(
317
+ (os.path.join(os.path.dirname(__file__), 'user.png')),
318
+ (os.path.join(os.path.dirname(__file__), "bot.png")))).style(height=440)
319
+
320
+ with gr.Tab("User Input"):
321
+ with gr.Row(scale=3):
322
+ user_input = gr.Textbox(label="Text", placeholder="Key in something here...", lines=3)
323
+ with gr.Row(scale=3):
324
+ with gr.Column(scale=1):
325
+ # image_btn = gr.UploadButton("🖼️ Upload Image", file_types=["image"])
326
+ image_path = gr.Image(type="filepath",
327
+ label="Image") # .style(height=200) # <PIL.Image.Image image mode=RGB size=512x512 at 0x7F6E06738D90>
328
+ with gr.Column(scale=1):
329
+ audio_path = gr.Audio(type='filepath') # .style(height=200)
330
+ with gr.Column(scale=1):
331
+ video_path = gr.Video() # .style(height=200) # , value=None, interactive=True
332
+ with gr.Column(scale=0.3, min_width=300):
333
+ with gr.Group():
334
+ with gr.Accordion('Text Advanced Options', open=True):
335
+ top_p = gr.Slider(0, 1, value=0.8, step=0.01, label="Top P", interactive=True)
336
+ temperature = gr.Slider(0, 1, value=0.6, step=0.01, label="Temperature", interactive=True)
337
+ with gr.Accordion('Audio Advanced Options', open=False):
338
+ audio_length_in_s = gr.Slider(5, 30, value=30, step=1, label="The audio length in seconds",
339
+ interactive=True)
340
+ with gr.Tab("Operation"):
341
+ with gr.Row(scale=1):
342
+ submitBtn = gr.Button(value="Submit & Run", variant="primary")
343
+ with gr.Row(scale=1):
344
+ emptyBtn = gr.Button("Clear History")
345
+
346
+ history = gr.State([])
347
+ modality_cache = gr.State([])
348
+
349
+ submitBtn.click(
350
+ predict, [
351
+ user_input,
352
+ image_path,
353
+ audio_path,
354
+ video_path,
355
+ chatbot,
356
+ top_p,
357
+ temperature,
358
+ history,
359
+ modality_cache,
360
+ audio_length_in_s
361
+ ], [
362
+ chatbot,
363
+ history,
364
+ modality_cache,
365
+ image_path,
366
+ audio_path,
367
+ video_path
368
+ ],
369
+ show_progress=True
370
+ )
371
+
372
+ submitBtn.click(reset_user_input, [], [user_input])
373
+ emptyBtn.click(reset_state, outputs=[
374
+ image_path,
375
+ audio_path,
376
+ video_path,
377
+ chatbot,
378
+ history,
379
+ modality_cache
380
+ ], show_progress=True)
381
+
382
+ demo.queue().launch(share=True, inbrowser=True, server_name='0.0.0.0', server_port=24000)
bot.png ADDED
user.png ADDED