AmithAdiraju1694 committed
Commit 271d9ed · verified · 1 Parent(s): 470188c

refactor_inference (#1)


- Modified inference order, moved transcription and summarization into separate functions, and replaced the regular PyTorch model with an ONNX one for faster inference (2beef55857c2f71271d604afca602cc2e7376e50). See the export sketch after the file list below.

Files changed (5)
  1. app.py +2 -1
  2. model_inference.py +30 -25
  3. pages.py +30 -24
  4. requirements.txt +1 -0
  5. video_rating_siamesev2.onnx +3 -0
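
The new video_rating_siamesev2.onnx file was presumably produced by exporting the repo's SiameseNetwork to ONNX. The commit itself does not include the export script; below is a minimal sketch of what such an export could look like, assuming the model's forward pass accepts float32 batches of shape (N, 5, 224, 224, 3) and returns the embedding. The input name "frames" and output name "emb" come from the model_inference.py diff; the dynamic batch axis is an assumption.

```python
# Hypothetical export sketch, not the author's actual script.
import torch
from model_inference import SiameseNetwork  # model class defined in this repo

model = SiameseNetwork()  # trained weights would be loaded here
model.eval()

dummy = torch.randn(1, 5, 224, 224, 3, dtype=torch.float32)  # one group of 5 frames
torch.onnx.export(
    model,
    dummy,
    "video_rating_siamesev2.onnx",
    input_names=["frames"],   # matches inputs_dict = {"frames": ...} below
    output_names=["emb"],     # matches video_rating_model.run(['emb'], ...)
    dynamic_axes={"frames": {0: "batch"}, "emb": {0: "batch"}},  # assumed
)
```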
app.py CHANGED
@@ -1,6 +1,7 @@
 import streamlit as st
 from streamlit import session_state as sst
 import asyncio
+import torch
 
 from pages import landing_page, model_inference_page
 
@@ -9,7 +10,7 @@ if "page" not in sst:
 
 
 def reset_sst():
     for key in list(sst.keys()):
-        if key != "page":
+        if key != "page" and key != 'device':
            sst.pop(key, None)
 
 
model_inference.py CHANGED
@@ -14,6 +14,8 @@ from utils import timer
 import numpy as np
 import whisper
 from utils import batch_generator, cosine_sim
+from streamlit import session_state as sst
+import onnxruntime
 
 
 
@@ -67,13 +69,19 @@ class SiameseNetwork(nn.Module):
 
 
 @timer
-def summarize_from_audio(audio_tensor):
-
+def get_text_from_audio(audio_tensors):
+    """Transcribe multiple audio tensors in parallel using Whisper's batch processing."""
     # Transcribe the in-memory audio
-    result = audio_transcriber_model.transcribe(audio_tensor)
+    audio_tensors = audio_tensors.to(sst['device'])
+    result = audio_transcriber_model.transcribe(audio_tensors)
     all_transcription_segments = result["text"]
+    return all_transcription_segments
+
+@timer
+def summarize_from_text(raw_transcription):
 
-    summary = text_summarizer(prompt_audio_summarization + all_transcription_segments,
+    summary = text_summarizer(prompt_audio_summarization + raw_transcription,
                               max_length=108,
                               min_length=36, do_sample=False)[0]['summary_text']
 
@@ -125,44 +133,41 @@ def rate_video_frames(video_frames):
     Classifies video frames into another category.
     """
 
-    tensor = torch.tensor(
-        np.array(video_frames),
-        dtype = torch.float32
-    ).reshape(len(video_frames)//5, 5, 224, 224, 3)  # 20,5,224,224,3
-    video_frame_emb = video_rating_model.inference(tensor)  # 20,128
+    inp_frames = np.array(video_frames, dtype = np.float32).reshape(len(video_frames)//5, 5, 224, 224, 3)  # 20,5,224,224,3
+    inputs_dict = {"frames": inp_frames}
+
+    video_frame_emb = video_rating_model.run(['emb'], inputs_dict)[0]
 
     overall_sim, count_upg = cosine_sim(emb1 = base_frame_emb,
-                                        emb2 = video_frame_emb,
+                                        emb2 = torch.tensor(video_frame_emb),
                                         threshold=0.4
                                         )
 
-    if count_upg / (len(video_frames)//5) > 0.5:
+    perc_of_upg = count_upg / (len(video_frames)//5)
+
+    if perc_of_upg > 0.4:
         return f"Out of {len(video_frames)} important moments of this video, {count_upg*5} moments contain under or at least PG content. Hence this video is suitable for kids & family."
     else:
         return f"Out of {len(video_frames)} important moments of this video, {(len(video_frames)//5 - count_upg)*5} moments contain at least PG-13 content. Hence parental guidance is strongly suggested for this video."
 
 @st.cache_resource
 def load_models():
-
-    transcriber = whisper.load_model("base")
-    summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
+    sst['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'
+    transcriber = whisper.load_model("base", device = sst['device'])
+    summarizer = pipeline("summarization", model="facebook/bart-large-cnn", device = sst['device'])
 
     base_frame_emb = torch.tensor(
         np.load('base_frame_medoid.npz')['arr'],
-        dtype = torch.float32
+        dtype = torch.float32,
+        device = sst['device']
     )
 
-    video_rating_model = SiameseNetwork()
-    # video_rating_model.load_state_dict(
-    #     torch.load('/Users/amithadiraju/Desktop/Video_Summary_App/video_contrastive-siamese_v3.pt',
-    #         weights_only = True
-    #     )
-    # )
-    video_rating_model.eval()
-
+    session = onnxruntime.InferenceSession("video_rating_siamesev2.onnx",
+                                           providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
+                                           )
+
     return (
-        transcriber, summarizer, video_rating_model, base_frame_emb
+        transcriber, summarizer, session, base_frame_emb
    )
 
 audio_transcriber_model, text_summarizer, video_rating_model, base_frame_emb = load_models()
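
Since the rating model is now a plain onnxruntime session, it can be sanity-checked outside Streamlit. A minimal sketch using only names that appear in this diff; the expected (20, 128) output shape comes from the removed "# 20,128" comment:

```python
# Standalone smoke test for the exported rating model.
import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession(
    "video_rating_siamesev2.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],  # CPU fallback
)

frames = np.random.rand(20, 5, 224, 224, 3).astype(np.float32)  # 100 frames in groups of 5
emb = session.run(["emb"], {"frames": frames})[0]
print(emb.shape)  # expected: (20, 128), per the old inline comment
```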
pages.py CHANGED
@@ -5,7 +5,7 @@ import time
 import pandas as pd
 from utils import navigate_to
 
-from model_inference import rate_video_frames,summarize_from_audio
+from model_inference import rate_video_frames,get_text_from_audio, summarize_from_text
 from utils import read_important_frames, extract_audio
 import numpy as np
 
@@ -61,28 +61,12 @@ async def landing_page():
 
 
 async def model_inference_page():
-
-    df = pd.DataFrame([('Video_Text_Summary', 'Video_Rating_Scale')])
-    sl_df = st.table(df)
-
-    # check if audio is present and it's non-empty
-    if "audio_transcript" in sst:
-
-        video_summary_text = summarize_from_audio(sst["audio_transcript"])
-
-        if len(video_summary_text) > 0:
-            pass
-        else:
-            video_summary_text = "Sorry, we couldn't find any audio data from your video, hence couldn't generate any summary"
-
-        print("Time taken to generate text summary from audio in seconds: ", summarize_from_audio.total_time)
-
 
     # check if frames are present and they are non-empty
     if "important_frames" in sst:
 
         important_frames = sst["important_frames"]
-        with st.spinner("Generating text summary for your video"):
+        with st.spinner("Generating Movie Scale rating for your video"):
             video_rating_scale = rate_video_frames(important_frames)
 
         if len(video_rating_scale) > 0:
@@ -90,13 +74,35 @@ async def model_inference_page():
         else:
             video_rating_scale = "Sorry, we couldn't find any images from your video, hence couldn't generate any summary"
 
-        print("Time taken to generate video rating in seconds: ", rate_video_frames.total_time)
+        st.toast("Done")
+        st.header("Movie Scale Rating of Your Video: ", divider = True)
+        st.write(video_rating_scale)
+        st.markdown("************************")
+
+
+    # check if audio is present and it's non-empty
+    if "audio_transcript" in sst:
+
+        with st.spinner("Extracting text from audio file"):
+            video_raw_text = get_text_from_audio(sst["audio_transcript"])
+        st.toast("Done")
+
+        with st.spinner("Summarizing text from entire transcript"):
+            video_summary_text = summarize_from_text(video_raw_text)
+        st.toast("Done")
+
+        if len(video_summary_text) > 0:
+            pass
+        else:
+            video_summary_text = "Sorry, we couldn't find any audio data from your video, hence couldn't generate any summary"
+
+        print("Time taken to get raw text from audio in seconds: ", get_text_from_audio.total_time)
+        print("Time taken to generate text summary from raw text in seconds: ", summarize_from_text.total_time)
+
+        st.header("Audio Transcript summary of your video: ", divider = True)
+        st.write(video_summary_text)
 
-    sl_df.add_rows(
-
-        [( video_summary_text, video_rating_scale ) ]
-
-    )
 
     st.button("Go Home",
               on_click = navigate_to,
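
pages.py reads get_text_from_audio.total_time and summarize_from_text.total_time right after calling the functions, so the @timer decorator from utils evidently records the last runtime as an attribute on the wrapped function. The decorator itself is not part of this commit; below is a minimal sketch of one plausible implementation, purely an assumption about how utils.timer behaves:

```python
# Hypothetical sketch of utils.timer; the repo's real implementation may differ.
import time
from functools import wraps

def timer(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = func(*args, **kwargs)
        # expose the last call's duration, as read in pages.py
        wrapper.total_time = time.perf_counter() - start
        return result

    wrapper.total_time = 0.0  # defined even before the first call
    return wrapper
```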
requirements.txt CHANGED
@@ -23,6 +23,7 @@ PyYAML==6.0.2
 safetensors==0.4.5
 scipy==1.13.1
 sentencepiece==0.2.0
+onnxruntime-gpu==1.17.1
 smmap==5.0.2
 sniffio==1.3.1
 soundfile==0.13.1
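
onnxruntime-gpu ships the CUDAExecutionProvider requested in load_models(); when no compatible GPU or CUDA runtime is present, the session falls back to the CPUExecutionProvider listed after it. A quick way to confirm which providers a given machine actually offers:

```python
import onnxruntime

# Lists execution providers available in this build/environment,
# e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider'] on a GPU box.
print(onnxruntime.get_available_providers())
```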
video_rating_siamesev2.onnx ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3f073fb71728915d53cde75d316c3196c07bda3fe79af5acab6596bc397146b6
+size 344064697