import functools
import heapq
import subprocess
import time
from io import BytesIO
from typing import List

import av
import cv2
import numpy as np
import soundfile as sf
import torch
import torch.nn.functional as F
from streamlit import session_state as sst

from preprocessing import preprocess_images

prompt_audio_summarization = "This is a video transcript; tell me what it is about: "

def timer(func):
    """Decorator that times each call to `func` and accumulates the total runtime on the wrapper."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.time()
        result = func(*args, **kwargs)
        duration = time.time() - start
        wrapper.total_time += duration
        print(f"Execution time of {func.__name__}: {duration:.4f}s")
        return result

    wrapper.total_time = 0
    return wrapper

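# Illustrative usage (a sketch, not part of the app flow): the decorator
# exposes the cumulative runtime via the wrapper's `total_time` attribute.
#
#   @timer
#   def slow_add(a, b):
#       time.sleep(0.1)
#       return a + b
#
#   slow_add(1, 2)
#   slow_add(3, 4)
#   print(slow_add.total_time)  # cumulative seconds across both calls
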
def navigate_to(page: str) -> None:
    """
    Set the current page in Streamlit's session state, a helper for
    simulating navigation between pages.

    Parameters:
        page: str, required. Name of the page to navigate to.

    Returns:
        None
    """
    sst["page"] = page

def read_important_frames(video_bytes: bytes, top_k_frames: int) -> List:
    """
    Decode an uploaded video from memory and return the `top_k_frames` frames
    with the highest inter-frame movement, in chronological order.
    """
    # Read the uploaded video entirely in memory
    video_io = BytesIO(video_bytes)

    # Open the uploaded video container
    container = av.open(video_io, format="mp4")

    prev_frame = None
    important_frames = []

    # For each frame, score its movement and keep the top_k frames in a min-heap
    for frame_id, frame in enumerate(container.decode(video=0)):  # Decode all frames
        img = frame.to_ndarray(format="bgr24")  # Convert the frame to a NumPy array (BGR format)
        assert len(img.shape) == 3, f"Expected a 3-D (H, W, C) frame, got shape: {img.shape}"

        if prev_frame is not None:
            # Compute the frame difference in grayscale for efficiency
            diff = cv2.absdiff(prev_frame, img)
            gray_diff = cv2.cvtColor(diff, cv2.COLOR_BGR2GRAY)
            movement_score = np.sum(gray_diff)  # Sum of pixel differences

            processed_frame = preprocess_images(frame.to_ndarray(format="rgb24"), 224, 224)

            # Keep only the top_k highest-movement frames; the unique frame_id
            # breaks score ties, so NumPy arrays are never compared directly
            if len(important_frames) < top_k_frames:
                heapq.heappush(important_frames, (movement_score, frame_id, processed_frame))
            else:
                heapq.heappushpop(important_frames, (movement_score, frame_id, processed_frame))

        prev_frame = img  # Update the previous frame

    # Sort the top_k frames back into chronological order of appearance
    important_frames = [item[2] for item in sorted(important_frames, key=lambda x: x[1])]

    return important_frames

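# Illustrative usage (a sketch; assumes "sample.mp4" exists on disk and that
# preprocess_images returns one array per frame):
#
#   with open("sample.mp4", "rb") as f:
#       frames = read_important_frames(f.read(), top_k_frames=10)
#   # `frames` holds the 10 highest-movement frames, oldest first
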
def extract_audio(video_bytes: bytes) -> torch.Tensor:
    """Extracts mono 16 kHz audio from a video given as bytes, without writing temp files."""
    # Run FFmpeg to extract raw WAV audio, piping bytes in and out
    process = subprocess.run(
        ["ffmpeg", "-i", "pipe:0", "-ac", "1", "-ar", "16000", "-c:a", "pcm_s16le", "-f", "wav", "pipe:1"],
        input=video_bytes,
        stdout=subprocess.PIPE,
        stderr=subprocess.DEVNULL
    )
    if process.returncode != 0:
        raise RuntimeError("FFmpeg failed to extract audio from the uploaded video.")

    # Wrap FFmpeg's output in a BytesIO stream
    audio_stream = BytesIO(process.stdout)

    # Read the audio stream into a NumPy array
    audio_array, sample_rate = sf.read(audio_stream, dtype="float32")

    # Convert to a PyTorch tensor (Whisper expects a torch.Tensor)
    audio_tensor = torch.tensor(audio_array)

    return audio_tensor

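# Illustrative usage (a sketch; "video.mp4" is a hypothetical local file):
#
#   with open("video.mp4", "rb") as f:
#       waveform = extract_audio(f.read())
#   # `waveform` is a 1-D float32 tensor at 16 kHz, suitable for a speech
#   # model such as Whisper (e.g. model.transcribe(waveform) in openai-whisper)
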
def batch_generator(array_list, batch_size=5):
    """
    Generator that yields batches of NumPy arrays stacked along a new first dimension.
    Note: a trailing partial batch (fewer than `batch_size` arrays) is dropped.

    Parameters:
        array_list (list of np.ndarray): List of NumPy arrays of shape (H, W, C).
        batch_size (int): Number of arrays per batch (default is 5).

    Yields:
        np.ndarray: A batch of shape (batch_size, H, W, C).
    """
    for i in range(0, len(array_list), batch_size):
        batch = array_list[i:i + batch_size]
        if len(batch) == batch_size:  # Skip the incomplete final batch
            yield np.stack(batch, axis=0)

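# Illustrative usage (a sketch with dummy data):
#
#   arrays = [np.zeros((224, 224, 3)) for _ in range(12)]
#   for batch in batch_generator(arrays, batch_size=5):
#       print(batch.shape)  # (5, 224, 224, 3); the trailing 2 arrays are dropped
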
def cosine_sim(emb1, emb2, threshold=0.5):
    """Return the mean cosine similarity of two embedding batches and the count of pairs above `threshold`."""
    similarities = F.cosine_similarity(emb1, emb2)
    # .item() works on any device, unlike .numpy(), which fails for CUDA tensors
    counts = torch.count_nonzero(similarities > threshold).item()
    return (similarities.mean(), counts)
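
# Illustrative usage (a sketch with random embeddings):
#
#   a = torch.randn(8, 512)
#   b = torch.randn(8, 512)
#   mean_sim, n_similar = cosine_sim(a, b, threshold=0.5)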