refactor_eh (#2)
Added exception handling, fixed models and tokenizer being on different devices, and fixed boilerplate code (e8f31c7e08bdb0472e2071aa660fea5ffea2eda7)
- README.md +1 -3
- app.py +0 -2
- model_inference.py +26 -111
- pages.py +38 -36
- runtime.txt +0 -1
- utils.py +2 -5
README.md
CHANGED
@@ -4,10 +4,8 @@ emoji: 👁
 colorFrom: indigo
 colorTo: indigo
 sdk: streamlit
-python_version: 3.9.6
+python_version: 3.9.6-slim
 sdk_version: 1.42.0
 app_file: app.py
 pinned: false
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
CHANGED
@@ -1,7 +1,5 @@
-import streamlit as st
 from streamlit import session_state as sst
 import asyncio
-import torch
 
 from pages import landing_page, model_inference_page
 
model_inference.py
CHANGED
@@ -1,75 +1,19 @@
-from transformers import pipeline
-import torch
-from PIL import Image
-
-import torch.nn as nn
-import torchvision.models as models
-import torch.nn.functional as F
-
-from PIL import Image
-from utils import prompt_frame_summarization, assistant_role, prompt_audio_summarization
 import streamlit as st
-
-
+import torch
+from utils import (
+    prompt_audio_summarization,
+    timer,
+    cosine_sim
+)
+from transformers import BartForConditionalGeneration, BartTokenizer
 import numpy as np
 import whisper
-from utils import batch_generator, cosine_sim
 from streamlit import session_state as sst
 import onnxruntime
 
 
-
-class SiameseNetwork(nn.Module):
-    def __init__(self, model_name="vit_b_16"):
-        super(SiameseNetwork, self).__init__()
-
-        self.encoder = models.vit_b_16(weights="IMAGENET1K_V1")  # Pretrained ViT
-        self.encoder.heads = nn.Identity()  # Remove classification head
-
-        self.fc = nn.Linear(768, 128)  # Reduce to 128-d embedding
-
-    def forward(self, video_frames1, video_frames2):
-        """
-        video1: (B, nf, H, W, C)  # Batch of videos (50 frames each)
-        video2: (B, nf, H, W, C)
-        """
-        B,num_frames,H,W,C = video_frames1.shape  # (Batch, Channels, H, W)
-
-        # Flatten frames into batch dimension for ViT
-        video_frames1 = video_frames1.permute(0,1,4,2,3).reshape(B * num_frames, C,H,W)
-        video_frames2 = video_frames2.reshape(0,1,4,2,3).reshape(B * num_frames, C,H,W)
-
-        # Extract frame-level embeddings
-        emb1 = self.encoder(video_frames1)  # (B*num_frames, 768)
-        emb2 = self.encoder(video_frames2)
-
-        # Reshape back to (B, T, 768) and average over T
-        #TODO: Change this to use LSTM instead of averaging
-        emb1 = emb1.reshape(B, num_frames, -1).mean(dim=1)  # (B, 768)
-        emb2 = emb2.reshape(B, num_frames, -1).mean(dim=1)
-
-        # Pass through fully connected layer
-        emb1 = self.fc(emb1)  # (B, 128)
-        emb2 = self.fc(emb2)
-
-        return emb1, emb2
-
-    def inference(self, video_frames):
-        """
-        video: (B, 50, C, H, W)
-        """
-        B, num_frames, H, W, C = video_frames.shape
-
-        video_frames = video_frames.permute(0,1,4,2,3).reshape(B * num_frames, C,H,W)
-        emb = self.encoder(video_frames)
-        emb = emb.reshape(B, num_frames, -1).mean(dim=1)
-        emb = self.fc(emb)
-
-        return emb
-
-
 @timer
-def get_text_from_audio(audio_tensors):
+def get_text_from_audio(audio_tensors) -> str:
     """Transcribe multiple audio tensors in parallel using Whisper's batch processing."""
     # Transcribe the in-memory audio
     audio_tensors = audio_tensors.to(sst['device'])
@@ -80,52 +24,21 @@ def get_text_from_audio(audio_tensors):
 
 @timer
 def summarize_from_text(raw_transcription):
-
-    summary = text_summarizer(prompt_audio_summarization + raw_transcription,
-                              max_length=108,
-                              min_length=36, do_sample=False)[0]['summary_text']
-
 
-    """
-
-    processor = None
-    messages = None
-    model = None
-    tokenizer = None
-
-    if video_frames is None or len(video_frames) == 0:
-        return "Error: No video frames available."
-
-    # Ensure frames are properly formatted
-    video_frames = [Image.fromarray(frame.astype("uint8")) for frame in video_frames]
-
-    # Ensure correct format for processor
-    inputs = processor(messages, images=None, videos=[video_frames])
-
-    inputs.update({
-        "tokenizer": tokenizer,
-        "max_new_tokens": 54,
-        "decode_text": True,
-    })
-
-    summary_text = model.generate(**inputs)
-
-    return summary_text
+    inputs = text_summarizer[0](prompt_audio_summarization + raw_transcription,
+                                return_tensors="pt",
+                                max_length=1024,
+                                truncation=True)\
+                                .to(sst['device'])
+
+    summary_ids = text_summarizer[1].generate(**inputs,
+                                              max_length=150,
+                                              min_length=30,
+                                              length_penalty=2.0,
+                                              num_beams=4
+                                              )
+
+    return text_summarizer[0].decode(summary_ids[0], skip_special_tokens=True)
 
 @timer
 def rate_video_frames(video_frames):
@@ -154,7 +67,9 @@ def rate_video_frames(video_frames):
 def load_models():
     sst['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'
     transcriber = whisper.load_model("base", device = sst['device'])
-
+
+    model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn").to(sst['device'])
+    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
 
     base_frame_emb = torch.tensor(
         np.load('base_frame_medoid.npz')['arr'],
@@ -167,7 +82,7 @@ def load_models():
     )
 
     return (
-        transcriber,
+        transcriber, (tokenizer, model), session, base_frame_emb
     )
 
 audio_transcriber_model, text_summarizer, video_rating_model,base_frame_emb = load_models()
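For reference, the device fix in model_inference.py comes down to keeping the BART model and its tokenized inputs on the same device. A minimal standalone sketch of that pattern (model name, prompt idea, and generation settings taken from the diff; everything else is illustrative, not the app code):

# Illustrative sketch: model is moved to the chosen device once, and every tokenized
# input is moved to that same device before generate() is called.
import torch
from transformers import BartForConditionalGeneration, BartTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn").to(device)

def summarize(text: str) -> str:
    # Tokenize on CPU, then move the input tensors to the model's device.
    inputs = tokenizer(text, return_tensors="pt", max_length=1024, truncation=True).to(device)
    summary_ids = model.generate(**inputs, max_length=150, min_length=30,
                                 length_penalty=2.0, num_beams=4)
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

print(summarize("This is a video transcript, tell me what is this about: ..."))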
pages.py
CHANGED
@@ -1,13 +1,9 @@
 import streamlit as st
 from streamlit import session_state as sst
-import time
-
-import pandas as pd
 from utils import navigate_to
 
 from model_inference import rate_video_frames,get_text_from_audio, summarize_from_text
 from utils import read_important_frames, extract_audio
-import numpy as np
 
 
 # Define size limits (adjust based on your system)
@@ -33,25 +29,35 @@ async def landing_page():
     else:
         # bytes object which can be translated to audio or video
        video_bytes = uploaded_file.read()
+
+        # Try to get important frames from this video, if not don't add this key for further inference processing
         with st.spinner("Getting most important moments from your video."):
-            important_frames = read_important_frames(video_bytes, 100)
-        st.success(f"Got important moments.")
+            try:
+                important_frames = read_important_frames(video_bytes, 100)
+
+                st.success(f"Got important moments.")
+
+                # add important frames to session state and redirect to model inference page
+                sst["important_frames"] = important_frames
+
+            except Exception as e:
+                st.write(f"Sorry couldn't extract important frames from this video & can't rate this on movie scale, because of error: {e}")
+
+        # Try to get audio from this video, if not don't add this key for further inference processing
         with st.spinner("Getting audio transcript from your video for summary"):
-
-        # add audio transcript to session state
-        sst["audio_transcript"] = audio_transcript_bytes
+            try:
+                audio_transcript_bytes = extract_audio(video_bytes)
+
+                st.success(f"Got audio transcript.")
+
+                # add audio transcript to session state
+                sst["audio_transcript"] = audio_transcript_bytes
+
+            except Exception as e:
+                st.write(f"Sorry couldn't extract audio from this video & can't rate summarize it, because of error: {e}")
 
     st.button("Summarize & Analyze Video",
               on_click = navigate_to,
@@ -67,13 +73,11 @@ async def model_inference_page():
 
     important_frames = sst["important_frames"]
     with st.spinner("Generating Movie Scale rating for your video"):
+        try:
+            video_rating_scale = rate_video_frames(important_frames)
+        except Exception as e:
+            video_rating_scale = f"Sorry, we couldn't generate rating of your video because of this error: {e} "
 
-        if len(video_rating_scale) > 0:
-            pass
-        else:
-            video_rating_scale = "Sorry, we couldn't find any images from your video, hence couldn't generate any summary"
-
     st.toast("Done")
     st.header("Movie Scale Rating of Your Video: ", divider = True)
     st.write(video_rating_scale)
@@ -84,21 +88,19 @@ async def model_inference_page():
     if "audio_transcript" in sst:
 
         with st.spinner("Extracting text from audio file"):
-            video_summary_text = summarize_from_text(video_raw_text)
+            try:
+                video_summary_text = get_text_from_audio(sst["audio_transcript"])
+            except Exception as e:
+                video_summary_text = f"Sorry, we couldn't extract text from audio of this file because of this error: {e} "
        st.toast("Done")
 
-        print("Time taken to generate text summary from raw text in seconds: ", summarize_from_text.total_time)
+        if video_summary_text[:5] != "Sorry":
+            with st.spinner("Summarizing text from entire transcript"):
+                try:
+                    video_summary_text = summarize_from_text(video_summary_text)
+                except Exception as e:
+                    video_summary_text = f"Sorry, we couldn't summarize text from audio of this file because of this error: {e} "
+            st.toast("Done")
 
         st.header("Audio Transcript summary of your video: ", divider = True)
         st.write(video_summary_text)
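The exception handling added in pages.py follows one pattern throughout: wrap each inference step, fall back to a "Sorry..." message, and skip later steps when an earlier one failed. A self-contained sketch of that pattern (the stage functions below are stand-ins, not the real model calls):

# Illustrative sketch: each stage either returns a real result or a "Sorry..." string,
# and a later stage only runs when the earlier one succeeded, so one failing model
# no longer crashes the whole page.
def run_stage(stage_fn, *args) -> str:
    try:
        return stage_fn(*args)
    except Exception as e:
        return f"Sorry, this step failed because of this error: {e}"

def transcribe(_audio) -> str:      # stand-in for get_text_from_audio
    return "a transcript of the uploaded video"

def summarize(text: str) -> str:    # stand-in for summarize_from_text
    return "a short summary of: " + text

transcript = run_stage(transcribe, b"audio-bytes")
summary = run_stage(summarize, transcript) if not transcript.startswith("Sorry") else transcript
print(summary)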
runtime.txt
DELETED
@@ -1 +0,0 @@
-3.9.*
utils.py
CHANGED
@@ -10,17 +10,14 @@ import numpy as np
 from preprocessing import preprocess_images
 import time
 
-import io
 from io import BytesIO
 import torch
 import soundfile as sf
 import subprocess
 from typing import List
 
-
-prompt_frame_summarization = "These are important frames of a video file. Please generate summary such that end user gets gist of what the video is about."
 prompt_audio_summarization = "This is a video transcript, tell me what is this about: "
 
 def timer(func):
     def wrapper(*args, **kwargs):
@@ -52,7 +49,7 @@ def navigate_to(page: str) -> None:
 def read_important_frames(video_bytes, top_k_frames) -> List:
 
     # reading uploaded vidoe in memory
-    video_io =
+    video_io = BytesIO(video_bytes)
 
     # opening uploaded video frames
     container = av.open(video_io, format='mp4')
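The utils.py fix wraps the uploaded bytes in BytesIO so PyAV can open the video entirely in memory. A minimal sketch of that decode path (read_important_frames itself also selects the top-k frames, which is not shown; this loop is only illustrative):

# Illustrative sketch: decode an in-memory MP4 into RGB numpy frames with PyAV.
from io import BytesIO
import av

def decode_frames(video_bytes: bytes, max_frames: int = 100):
    video_io = BytesIO(video_bytes)          # same idea as the fixed line in utils.py
    container = av.open(video_io, format="mp4")
    frames = []
    for frame in container.decode(video=0):  # first video stream
        frames.append(frame.to_ndarray(format="rgb24"))
        if len(frames) >= max_frames:
            break
    container.close()
    return frames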