refactor_inference (#1)
Browse files- Modified inference order, moved transcription and summarization to separate functions and replaced regular pytorch model with onnx one for faster inference (2beef55857c2f71271d604afca602cc2e7376e50)
- app.py +2 -1
- model_inference.py +30 -25
- pages.py +30 -24
- requirements.txt +1 -0
- video_rating_siamesev2.onnx +3 -0
app.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import streamlit as st
|
2 |
from streamlit import session_state as sst
|
3 |
import asyncio
|
|
|
4 |
|
5 |
from pages import landing_page, model_inference_page
|
6 |
|
@@ -9,7 +10,7 @@ if "page" not in sst:
|
|
9 |
|
10 |
def reset_sst():
|
11 |
for key in list(sst.keys()):
|
12 |
-
if key != "page":
|
13 |
sst.pop(key, None)
|
14 |
|
15 |
|
|
|
1 |
import streamlit as st
|
2 |
from streamlit import session_state as sst
|
3 |
import asyncio
|
4 |
+
import torch
|
5 |
|
6 |
from pages import landing_page, model_inference_page
|
7 |
|
|
|
10 |
|
11 |
def reset_sst():
|
12 |
for key in list(sst.keys()):
|
13 |
+
if key != "page" and key != 'device':
|
14 |
sst.pop(key, None)
|
15 |
|
16 |
|
model_inference.py
CHANGED
@@ -14,6 +14,8 @@ from utils import timer
|
|
14 |
import numpy as np
|
15 |
import whisper
|
16 |
from utils import batch_generator, cosine_sim
|
|
|
|
|
17 |
|
18 |
|
19 |
|
@@ -67,13 +69,19 @@ class SiameseNetwork(nn.Module):
|
|
67 |
|
68 |
|
69 |
@timer
|
70 |
-
def
|
71 |
-
|
72 |
# Transcribe the in-memory audio
|
73 |
-
|
|
|
|
|
74 |
all_transcription_segments = result["text"]
|
|
|
|
|
|
|
|
|
75 |
|
76 |
-
summary = text_summarizer(prompt_audio_summarization +
|
77 |
max_length=108,
|
78 |
min_length=36, do_sample=False)[0]['summary_text']
|
79 |
|
@@ -125,44 +133,41 @@ def rate_video_frames(video_frames):
|
|
125 |
Classifies video frames into another category.
|
126 |
"""
|
127 |
|
128 |
-
|
129 |
-
|
130 |
-
dtype = torch.float32
|
131 |
-
).reshape(len(video_frames)//5, 5, 224,224,3) # 20,5,224,224,3
|
132 |
-
video_frame_emb = video_rating_model.inference(tensor) # 20,128
|
133 |
|
|
|
|
|
134 |
overall_sim, count_upg = cosine_sim(emb1 = base_frame_emb,
|
135 |
-
emb2 = video_frame_emb,
|
136 |
threshold=0.4
|
137 |
)
|
|
|
|
|
138 |
|
139 |
-
if
|
140 |
return f"Out of {len(video_frames)} important moments of this video, {count_upg*5} moments contain under or at least PG content. Hence this video is suitable for kids & family."
|
141 |
else:
|
142 |
return f"Out of {len(video_frames)} important moments of this video, {(len(video_frames)//5 - count_upg)*5} moments contain at least PG-13 content.Hence parental guidance is strongly suggested for this video."
|
143 |
|
144 |
@st.cache_resource
|
145 |
def load_models():
|
146 |
-
|
147 |
-
transcriber = whisper.load_model("base")
|
148 |
-
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
|
149 |
|
150 |
base_frame_emb = torch.tensor(
|
151 |
np.load('base_frame_medoid.npz')['arr'],
|
152 |
-
dtype = torch.float32
|
|
|
153 |
)
|
154 |
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
# weights_only = True
|
160 |
-
# )
|
161 |
-
# )
|
162 |
-
video_rating_model.eval()
|
163 |
-
|
164 |
return (
|
165 |
-
transcriber, summarizer,
|
166 |
)
|
167 |
|
168 |
audio_transcriber_model, text_summarizer, video_rating_model,base_frame_emb = load_models()
|
|
|
14 |
import numpy as np
|
15 |
import whisper
|
16 |
from utils import batch_generator, cosine_sim
|
17 |
+
from streamlit import session_state as sst
|
18 |
+
import onnxruntime
|
19 |
|
20 |
|
21 |
|
|
|
69 |
|
70 |
|
71 |
@timer
|
72 |
+
def get_text_from_audio(audio_tensors):
|
73 |
+
"""Transcribe multiple audio tensors in parallel using Whisper's batch processing."""
|
74 |
# Transcribe the in-memory audio
|
75 |
+
audio_tensors = audio_tensors.to(sst['device'])
|
76 |
+
result = audio_transcriber_model.transcribe(audio_tensors
|
77 |
+
)
|
78 |
all_transcription_segments = result["text"]
|
79 |
+
return all_transcription_segments
|
80 |
+
|
81 |
+
@timer
|
82 |
+
def summarize_from_text(raw_transcription):
|
83 |
|
84 |
+
summary = text_summarizer(prompt_audio_summarization + raw_transcription,
|
85 |
max_length=108,
|
86 |
min_length=36, do_sample=False)[0]['summary_text']
|
87 |
|
|
|
133 |
Classifies video frames into another category.
|
134 |
"""
|
135 |
|
136 |
+
inp_frames = np.array(video_frames, dtype = np.float32).reshape(len(video_frames)//5, 5, 224,224,3)# 20,5,224,224,3
|
137 |
+
inputs_dict = {"frames": inp_frames}
|
|
|
|
|
|
|
138 |
|
139 |
+
video_frame_emb = video_rating_model.run(['emb'], inputs_dict)[0]
|
140 |
+
|
141 |
overall_sim, count_upg = cosine_sim(emb1 = base_frame_emb,
|
142 |
+
emb2 = torch.tensor(video_frame_emb),
|
143 |
threshold=0.4
|
144 |
)
|
145 |
+
|
146 |
+
perc_of_upg = count_upg / (len(video_frames)//5)
|
147 |
|
148 |
+
if perc_of_upg > 0.4:
|
149 |
return f"Out of {len(video_frames)} important moments of this video, {count_upg*5} moments contain under or at least PG content. Hence this video is suitable for kids & family."
|
150 |
else:
|
151 |
return f"Out of {len(video_frames)} important moments of this video, {(len(video_frames)//5 - count_upg)*5} moments contain at least PG-13 content.Hence parental guidance is strongly suggested for this video."
|
152 |
|
153 |
@st.cache_resource
|
154 |
def load_models():
|
155 |
+
sst['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'
|
156 |
+
transcriber = whisper.load_model("base", device = sst['device'])
|
157 |
+
summarizer = pipeline("summarization", model="facebook/bart-large-cnn", device = sst['device'])
|
158 |
|
159 |
base_frame_emb = torch.tensor(
|
160 |
np.load('base_frame_medoid.npz')['arr'],
|
161 |
+
dtype = torch.float32,
|
162 |
+
device = sst['device']
|
163 |
)
|
164 |
|
165 |
+
session = onnxruntime.InferenceSession("video_rating_siamesev2.onnx",
|
166 |
+
providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
|
167 |
+
)
|
168 |
+
|
|
|
|
|
|
|
|
|
|
|
169 |
return (
|
170 |
+
transcriber, summarizer, session, base_frame_emb
|
171 |
)
|
172 |
|
173 |
audio_transcriber_model, text_summarizer, video_rating_model,base_frame_emb = load_models()
|
pages.py
CHANGED
@@ -5,7 +5,7 @@ import time
|
|
5 |
import pandas as pd
|
6 |
from utils import navigate_to
|
7 |
|
8 |
-
from model_inference import rate_video_frames,
|
9 |
from utils import read_important_frames, extract_audio
|
10 |
import numpy as np
|
11 |
|
@@ -61,28 +61,12 @@ async def landing_page():
|
|
61 |
|
62 |
|
63 |
async def model_inference_page():
|
64 |
-
|
65 |
-
df = pd.DataFrame([('Video_Text_Summary', 'Video_Rating_Scale')])
|
66 |
-
sl_df = st.table(df)
|
67 |
-
|
68 |
-
# check if audio is present and it's non-empty
|
69 |
-
if "audio_transcript" in sst:
|
70 |
-
|
71 |
-
video_summary_text = summarize_from_audio(sst["audio_transcript"])
|
72 |
-
|
73 |
-
if len(video_summary_text) > 0:
|
74 |
-
pass
|
75 |
-
else:
|
76 |
-
video_summary_text = "Sorry, we couldn't find any audio data from your video, hence couldn't generate any summary"
|
77 |
-
|
78 |
-
print("Time taken to generate text summary from audio in seconds: ", summarize_from_audio.total_time)
|
79 |
-
|
80 |
|
81 |
# check if frames are present and they are non-empty
|
82 |
if "important_frames" in sst:
|
83 |
|
84 |
important_frames = sst["important_frames"]
|
85 |
-
with st.spinner("Generating
|
86 |
video_rating_scale = rate_video_frames(important_frames)
|
87 |
|
88 |
if len(video_rating_scale) > 0:
|
@@ -90,13 +74,35 @@ async def model_inference_page():
|
|
90 |
else:
|
91 |
video_rating_scale = "Sorry, we couldn't find any images from your video, hence couldn't generate any summary"
|
92 |
|
93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
|
95 |
-
sl_df.add_rows(
|
96 |
-
|
97 |
-
[( video_summary_text, video_rating_scale ) ]
|
98 |
-
|
99 |
-
)
|
100 |
|
101 |
st.button("Go Home",
|
102 |
on_click = navigate_to,
|
|
|
5 |
import pandas as pd
|
6 |
from utils import navigate_to
|
7 |
|
8 |
+
from model_inference import rate_video_frames,get_text_from_audio, summarize_from_text
|
9 |
from utils import read_important_frames, extract_audio
|
10 |
import numpy as np
|
11 |
|
|
|
61 |
|
62 |
|
63 |
async def model_inference_page():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
|
65 |
# check if frames are present and they are non-empty
|
66 |
if "important_frames" in sst:
|
67 |
|
68 |
important_frames = sst["important_frames"]
|
69 |
+
with st.spinner("Generating Movie Scale rating for your video"):
|
70 |
video_rating_scale = rate_video_frames(important_frames)
|
71 |
|
72 |
if len(video_rating_scale) > 0:
|
|
|
74 |
else:
|
75 |
video_rating_scale = "Sorry, we couldn't find any images from your video, hence couldn't generate any summary"
|
76 |
|
77 |
+
st.toast("Done")
|
78 |
+
st.header("Movie Scale Rating of Your Video: ", divider = True)
|
79 |
+
st.write(video_rating_scale)
|
80 |
+
st.markdown("************************")
|
81 |
+
|
82 |
+
|
83 |
+
# check if audio is present and it's non-empty
|
84 |
+
if "audio_transcript" in sst:
|
85 |
+
|
86 |
+
with st.spinner("Extracting text from audio file"):
|
87 |
+
video_raw_text = get_text_from_audio(sst["audio_transcript"])
|
88 |
+
st.toast("Done")
|
89 |
+
|
90 |
+
with st.spinner("Summarizing text from entire transcript"):
|
91 |
+
video_summary_text = summarize_from_text(video_raw_text)
|
92 |
+
st.toast("Done")
|
93 |
+
|
94 |
+
|
95 |
+
if len(video_summary_text) > 0:
|
96 |
+
pass
|
97 |
+
else:
|
98 |
+
video_summary_text = "Sorry, we couldn't find any audio data from your video, hence couldn't generate any summary"
|
99 |
+
|
100 |
+
print("Time taken to get raw text from audio in seconds: ", get_text_from_audio.total_time)
|
101 |
+
print("Time taken to generate text summary from raw text in seconds: ", summarize_from_text.total_time)
|
102 |
+
|
103 |
+
st.header("Audio Transcript summary of your video: ", divider = True)
|
104 |
+
st.write(video_summary_text)
|
105 |
|
|
|
|
|
|
|
|
|
|
|
106 |
|
107 |
st.button("Go Home",
|
108 |
on_click = navigate_to,
|
requirements.txt
CHANGED
@@ -23,6 +23,7 @@ PyYAML==6.0.2
|
|
23 |
safetensors==0.4.5
|
24 |
scipy==1.13.1
|
25 |
sentencepiece==0.2.0
|
|
|
26 |
smmap==5.0.2
|
27 |
sniffio==1.3.1
|
28 |
soundfile==0.13.1
|
|
|
23 |
safetensors==0.4.5
|
24 |
scipy==1.13.1
|
25 |
sentencepiece==0.2.0
|
26 |
+
onnxruntime-gpu==1.17.1
|
27 |
smmap==5.0.2
|
28 |
sniffio==1.3.1
|
29 |
soundfile==0.13.1
|
video_rating_siamesev2.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3f073fb71728915d53cde75d316c3196c07bda3fe79af5acab6596bc397146b6
|
3 |
+
size 344064697
|