Gosula committed on
Commit 1819f26 · verified · 1 Parent(s): 38203f5

Update app.py

Files changed (1)
  1. app.py +56 -58
app.py CHANGED
@@ -1,20 +1,15 @@

  import gradio as gr
  import peft
- from peft import LoraConfig, PeftModel
- from transformers import AutoTokenizer, AutoModelForCausalLM, CLIPVisionModel, AutoProcessor
  import torch
- from PIL import Image
- import requests
- import numpy as np
  import torch.nn as nn
  import whisperx
- import ffmpeg, pydub
- from pydub import AudioSegment

  clip_model_name = "wkcn/TinyCLIP-ViT-61M-32-Text-29M-LAION400M"
  phi_model_name = "microsoft/phi-2"
-
  tokenizer = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=True)
  processor = AutoProcessor.from_pretrained(clip_model_name)
  tokenizer.pad_token = tokenizer.eos_token
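The import cleanup above drops PIL, requests, numpy, ffmpeg and pydub, but both versions keep the same pairing: the TinyCLIP vision tower for images and phi-2's tokenizer for text, with pad_token set to eos_token because the phi-2 tokenizer ships without a padding token. The hard-coded clip_embed = 640 / phi_embed = 2560 sizes further down have to match these two checkpoints; a minimal sketch to sanity-check them from the configs alone (assuming the TinyCLIP repo exposes a standard CLIPConfig with a nested vision_config):

import torch
from transformers import AutoConfig

clip_model_name = "wkcn/TinyCLIP-ViT-61M-32-Text-29M-LAION400M"
phi_model_name = "microsoft/phi-2"

# CLIP-style configs nest the vision tower settings under vision_config;
# its hidden_size is the feature width the projection layer will consume.
clip_cfg = AutoConfig.from_pretrained(clip_model_name)
clip_embed = clip_cfg.vision_config.hidden_size

# phi-2's hidden_size is the token-embedding width the projection must produce.
phi_cfg = AutoConfig.from_pretrained(phi_model_name, trust_remote_code=True)
phi_embed = phi_cfg.hidden_size

assert (clip_embed, phi_embed) == (640, 2560), (clip_embed, phi_embed)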
@@ -22,41 +17,49 @@ IMAGE_TOKEN_ID = 23893 # token for word comment
  device = "cuda" if torch.cuda.is_available() else "cpu"
  clip_embed = 640
  phi_embed = 2560
- compute_type = "float32"
- audio_batch_size = 1
-
- import gc
-
  # models
  clip_model = CLIPVisionModel.from_pretrained(clip_model_name).to(device)
-
  projection = torch.nn.Linear(clip_embed, phi_embed).to(device)
-
- gc.collect()
- phi_model = AutoModelForCausalLM.from_pretrained(
-     phi_model_name,
-     trust_remote_code=True,
- )
- audio_model = whisperx.load_model("small", device, compute_type=compute_type)

  # load weights
- model_to_merge = PeftModel.from_pretrained(phi_model,'./model_chkpt/')
- merged_model = model_to_merge.merge_and_unload().to(device)
- projection.load_state_dict(torch.load('./ft_projection.pth',map_location=torch.device(device)))

- def inference(img=None,img_audio=None,val_q=None):

-     max_generate_length = 50
      val_combined_embeds = []
-
      with torch.no_grad():
-
          # image
          if img is not None:
              image_processed = processor(images=img, return_tensors="pt").to(device)
              clip_val_outputs = clip_model(**image_processed).last_hidden_state[:,1:,:]
              val_image_embeds = projection(clip_val_outputs)
-
              img_token_tensor = torch.tensor(IMAGE_TOKEN_ID).to(device)
              img_token_embeds = merged_model.model.embed_tokens(img_token_tensor).unsqueeze(0).unsqueeze(0)
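The removed lines here load whisper "small" in float32 and read the adapter from ./model_chkpt/; their replacements (whisper "tiny" in float16, ./model_chkpt/qlora_adaptor, and a SimpleResBlock on top of the projection) appear in the new version of the file below. The image branch itself is unchanged in spirit: CLIP's last_hidden_state is sliced with [:, 1:, :] to drop the CLS token, and the remaining patch tokens go through the Linear(clip_embed, phi_embed) projection so they can be concatenated with phi-2 token embeddings. A shape-only sketch with random tensors, no checkpoints required (the patch count is an illustrative assumption):

import torch
import torch.nn as nn

clip_embed, phi_embed = 640, 2560
num_patches = 49  # hypothetical 7x7 grid; the real count depends on TinyCLIP's image and patch size

# Stand-in for clip_model(**image_processed).last_hidden_state: (batch, 1 + patches, clip_embed)
clip_hidden = torch.randn(1, 1 + num_patches, clip_embed)

projection = nn.Linear(clip_embed, phi_embed)

patch_tokens = clip_hidden[:, 1:, :]     # drop the CLS token, keep per-patch features
image_embeds = projection(patch_tokens)  # (1, num_patches, phi_embed): now in phi-2's embedding space

print(image_embeds.shape)                # torch.Size([1, 49, 2560])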
 
@@ -65,64 +68,59 @@ def inference(img=None,img_audio=None,val_q=None):

          # audio
          if img_audio is not None:
-
-             # accepting only initial few secs speech
-             audio = AudioSegment.from_mp3( img_audio)
-             clipped_audio = audio[:20*1000]
-             clipped_audio.export( 'audio.mp3', format="mp3")
-             result = audio_model.transcribe('audio.mp3')
              audio_text = ''
-
-             audio_text = result["segments"][0]['text']
              audio_text = audio_text.strip()
              audio_tokens = tokenizer(audio_text, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
              audio_embeds = merged_model.model.embed_tokens(audio_tokens).unsqueeze(0)
              val_combined_embeds.append(audio_embeds)
-
          # text question
          if len(val_q) != 0:
              val_q_tokenised = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
              val_q_embeds = merged_model.model.embed_tokens(val_q_tokenised).unsqueeze(0)
              val_combined_embeds.append(val_q_embeds)

-         # val_combined_emb
          val_combined_embeds = torch.cat(val_combined_embeds,dim=1)
-
          predicted_caption = torch.full((1,max_generate_length),50256).to(device)
-
          for g in range(max_generate_length):
-             phi_output_logits = merged_model(inputs_embeds=val_combined_embeds)['logits']
-             predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1)
-             predicted_word_token = torch.argmax(predicted_word_token_logits, dim = -1)
              predicted_caption[:,g] = predicted_word_token.view(1,-1)
-             next_token_embeds = phi_model.model.embed_tokens(predicted_word_token)
              val_combined_embeds = torch.cat([val_combined_embeds, next_token_embeds], dim=1)
-
          predicted_captions_decoded = tokenizer.batch_decode(predicted_caption,ignore_index = 50256)[0]
-
          return predicted_captions_decoded

  with gr.Blocks() as demo:

      gr.Markdown(
          """
-         # multi-modalLLM
-         Build using Tiny Clip model and Microsoft's Phi-2 model fine tuned on Instruct 150k.
          """
      )

      # app GUI
      with gr.Row():
          with gr.Column():
-             img_input = gr.Image(label='Reference Image',type="pil")
-             img_question = gr.Text(label ='Question related to Image')
-             img_audio = gr.Audio(label="Speak a question", sources=['microphone', 'upload'], type='filepath')
          with gr.Column():
-             img_answer = gr.Text(label ='Response')
-
-     section_btn = gr.Button("Process")
-     section_btn.click(inference, inputs=[img_input,img_audio,img_question], outputs=[img_answer])

  if __name__ == "__main__":
-     demo.launch(debug=True)
-
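The removed audio branch clipped the recording to its first 20 seconds with pydub, re-exported it as audio.mp3, and kept only the first whisperx segment; the new version below passes Gradio's filepath straight to whisperx and concatenates every segment. A minimal sketch of that updated flow, assuming whisperx is installed and question.mp3 is a hypothetical recording (any filepath Gradio hands over works the same way):

import torch
import whisperx

device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 needs a GPU backend; fall back to float32 so the sketch also runs on CPU.
compute_type = "float16" if device == "cuda" else "float32"

# "tiny" matches the updated app.py; the old code loaded "small" in float32.
audio_model = whisperx.load_model("tiny", device, compute_type=compute_type)

result = audio_model.transcribe("question.mp3")  # hypothetical file; whisperx loads the audio itself
audio_text = "".join(seg["text"] for seg in result["segments"]).strip()
print(audio_text)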
 

  import gradio as gr
  import peft
+ from peft import LoraConfig
+ from transformers import AutoTokenizer,BitsAndBytesConfig, AutoModelForCausalLM, CLIPVisionModel, AutoProcessor
  import torch
+ from peft import PeftModel
  import torch.nn as nn
  import whisperx

  clip_model_name = "wkcn/TinyCLIP-ViT-61M-32-Text-29M-LAION400M"
  phi_model_name = "microsoft/phi-2"
  tokenizer = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=True)
  processor = AutoProcessor.from_pretrained(clip_model_name)
  tokenizer.pad_token = tokenizer.eos_token
  device = "cuda" if torch.cuda.is_available() else "cpu"
  clip_embed = 640
  phi_embed = 2560
+ compute_type = "float16"
+ audio_batch_size = 16
+
+ class SimpleResBlock(nn.Module):
+     def __init__(self, phi_embed):
+         super().__init__()
+         self.pre_norm = nn.LayerNorm(phi_embed)
+         self.proj = nn.Sequential(
+             nn.Linear(phi_embed, phi_embed),
+             nn.GELU(),
+             nn.Linear(phi_embed, phi_embed)
+         )
+     def forward(self, x):
+         x = self.pre_norm(x)
+         return x + self.proj(x)
+
  # models
  clip_model = CLIPVisionModel.from_pretrained(clip_model_name).to(device)
  projection = torch.nn.Linear(clip_embed, phi_embed).to(device)
+ resblock = SimpleResBlock(phi_embed).to(device)
+ phi_model = AutoModelForCausalLM.from_pretrained(phi_model_name,trust_remote_code=True).to(device)
+ audio_model = whisperx.load_model("tiny", device, compute_type=compute_type)

  # load weights
+ model_to_merge = PeftModel.from_pretrained(phi_model,'./model_chkpt/qlora_adaptor')
+ merged_model = model_to_merge.merge_and_unload()
+ projection.load_state_dict(torch.load('./model_chkpt/ft_projection_layer.pth',map_location=torch.device(device)))
+ resblock.load_state_dict(torch.load('./model_chkpt/ft_projection_model.pth',map_location=torch.device(device)))

+ def model_generate_ans(img=None,img_audio=None,val_q=None):

+     max_generate_length = 100
      val_combined_embeds = []
+
      with torch.no_grad():
+
          # image
          if img is not None:
              image_processed = processor(images=img, return_tensors="pt").to(device)
              clip_val_outputs = clip_model(**image_processed).last_hidden_state[:,1:,:]
              val_image_embeds = projection(clip_val_outputs)
+             val_image_embeds = resblock(val_image_embeds).to(torch.float16)
+
              img_token_tensor = torch.tensor(IMAGE_TOKEN_ID).to(device)
              img_token_embeds = merged_model.model.embed_tokens(img_token_tensor).unsqueeze(0).unsqueeze(0)


          # audio
          if img_audio is not None:
+             audio_result = audio_model.transcribe(img_audio)
              audio_text = ''
+             for seg in audio_result['segments']:
+                 audio_text += seg['text']
              audio_text = audio_text.strip()
              audio_tokens = tokenizer(audio_text, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
              audio_embeds = merged_model.model.embed_tokens(audio_tokens).unsqueeze(0)
              val_combined_embeds.append(audio_embeds)
+
          # text question
          if len(val_q) != 0:
              val_q_tokenised = tokenizer(val_q, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0).to(device)
              val_q_embeds = merged_model.model.embed_tokens(val_q_tokenised).unsqueeze(0)
              val_combined_embeds.append(val_q_embeds)

          val_combined_embeds = torch.cat(val_combined_embeds,dim=1)
+
+         #val_combined_embeds = torch.cat([val_image_embeds, img_token_embeds, val_q_embeds], dim=1) # 4, 69, 2560
          predicted_caption = torch.full((1,max_generate_length),50256).to(device)
+
          for g in range(max_generate_length):
+             phi_output_logits = merged_model(inputs_embeds=val_combined_embeds)['logits'] # 4, 69, 51200
+             predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1) # 4,1,51200
+             predicted_word_token = torch.argmax(predicted_word_token_logits, dim = -1) # 4,1
              predicted_caption[:,g] = predicted_word_token.view(1,-1)
+             next_token_embeds = phi_model.model.embed_tokens(predicted_word_token) # 4,1,2560
              val_combined_embeds = torch.cat([val_combined_embeds, next_token_embeds], dim=1)
+
          predicted_captions_decoded = tokenizer.batch_decode(predicted_caption,ignore_index = 50256)[0]
+
          return predicted_captions_decoded
+

  with gr.Blocks() as demo:

      gr.Markdown(
          """
+         # Chat with MultiModal GPT !
+         Build using combining clip model and phi-2 model.
          """
      )

      # app GUI
      with gr.Row():
          with gr.Column():
+             img_input = gr.Image(label='Image',type="pil")
+             img_audio = gr.Audio(label="Audio Query", sources=['microphone', 'upload'], type='filepath')
+             img_question = gr.Text(label ='Text Query')
          with gr.Column():
+             img_answer = gr.Text(label ='Answer')

+     section_btn = gr.Button("Submit")
+     section_btn.click(model_generate_ans, inputs=[img_input,img_audio,img_question], outputs=[img_answer])
+
  if __name__ == "__main__":
+     demo.launch()
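The heart of model_generate_ans is a hand-rolled greedy loop over inputs_embeds: the concatenated image/audio/text embeddings are fed to the merged model, the argmax of the last position becomes the next token, its embedding is appended, and the loop repeats for max_generate_length steps into a buffer pre-filled with 50256 (the EOS id). A runnable sketch of the same pattern with plain gpt2 standing in for the merged phi-2 checkpoint (so no fine-tuned weights are needed; gpt2 shares the 50256 EOS id), plus an early stop on EOS that app.py itself does not do:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # lightweight stand-in; app.py uses merged_model (phi-2 + LoRA) here
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).eval()

prompt_ids = tokenizer("Describe the image:", return_tensors="pt")["input_ids"]
embed_layer = model.get_input_embeddings()  # app.py: merged_model.model.embed_tokens
combined_embeds = embed_layer(prompt_ids)   # image/audio embeddings would be torch.cat'd in front of this

max_generate_length = 20
generated = []
with torch.no_grad():
    for _ in range(max_generate_length):
        logits = model(inputs_embeds=combined_embeds).logits        # (1, seq_len, vocab)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy pick, shape (1, 1)
        if next_token.item() == tokenizer.eos_token_id:             # 50256; app.py pads with it instead
            break
        generated.append(next_token)
        combined_embeds = torch.cat([combined_embeds, embed_layer(next_token)], dim=1)

if generated:
    print(tokenizer.decode(torch.cat(generated, dim=1)[0]))

One caveat on the original loop: ignore_index is not a batch_decode parameter, so the 50256 padding still appears in the decoded answer; skip_special_tokens=True (or stopping at EOS as above) is the usual way to drop it.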