Amith Adiraju
Added exception handling, fixed models and tokenizer being on different devices, fixed boilerplate code
e8f31c7
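"""Model-inference helpers for the Streamlit app: Whisper speech-to-text,
BART (facebook/bart-large-cnn) summarization, and an ONNX Siamese model that
rates video frames against a precomputed base-frame embedding."""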
import streamlit as st
import torch
from utils import (
    prompt_audio_summarization,
    timer,
    cosine_sim
)
from transformers import BartForConditionalGeneration, BartTokenizer
import numpy as np
import whisper
from streamlit import session_state as sst
import onnxruntime

def get_text_from_audio(audio_tensors) -> str:
    """Transcribe multiple audio tensors in parallel using Whisper's batch processing."""
    # Transcribe the in-memory audio on the same device as the transcriber model.
    audio_tensors = audio_tensors.to(sst['device'])
    result = audio_transcriber_model.transcribe(audio_tensors)
    all_transcription_segments = result["text"]
    return all_transcription_segments

def summarize_from_text(raw_transcription):
    """Summarize the transcript with BART; text_summarizer is a (tokenizer, model) tuple."""
    # Tokenize the prompt + transcript and move the input tensors to the model's device.
    inputs = text_summarizer[0](prompt_audio_summarization + raw_transcription,
                                return_tensors="pt",
                                max_length=1024,
                                truncation=True).to(sst['device'])
    # Generate the summary with beam search.
    summary_ids = text_summarizer[1].generate(**inputs,
                                              max_length=150,
                                              min_length=30,
                                              length_penalty=2.0,
                                              num_beams=4)
    return text_summarizer[0].decode(summary_ids[0], skip_special_tokens=True)

def rate_video_frames(video_frames):
    """
    Rates video frames for age-appropriateness by comparing their embeddings
    against the precomputed base (PG) frame embedding.
    """
    # Group frames into chunks of 5, e.g. 100 frames -> shape (20, 5, 224, 224, 3).
    inp_frames = np.array(video_frames, dtype=np.float32).reshape(len(video_frames) // 5, 5, 224, 224, 3)
    inputs_dict = {"frames": inp_frames}
    video_frame_emb = video_rating_model.run(['emb'], inputs_dict)[0]
    # Count how many chunks are similar enough to the base frame embedding.
    overall_sim, count_upg = cosine_sim(emb1=base_frame_emb,
                                        emb2=torch.tensor(video_frame_emb),
                                        threshold=0.4)
    perc_of_upg = count_upg / (len(video_frames) // 5)

    if perc_of_upg > 0.4:
        return f"Out of {len(video_frames)} important moments of this video, {count_upg * 5} moments contain content rated PG or below. Hence this video is suitable for kids & family."
    else:
        return f"Out of {len(video_frames)} important moments of this video, {(len(video_frames) // 5 - count_upg) * 5} moments contain at least PG-13 content. Hence parental guidance is strongly suggested for this video."

def load_models():
    sst['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Whisper model for speech-to-text.
    transcriber = whisper.load_model("base", device=sst['device'])

    # BART model and tokenizer for summarization; the model lives on the same
    # device that the tokenized inputs are later moved to.
    model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn").to(sst['device'])
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")

    # Precomputed embedding of the base reference frame used by rate_video_frames.
    base_frame_emb = torch.tensor(
        np.load('base_frame_medoid.npz')['arr'],
        dtype=torch.float32,
        device=sst['device']
    )

    # ONNX Siamese model that embeds chunks of video frames.
    session = onnxruntime.InferenceSession("video_rating_siamesev2.onnx",
                                           providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])

    return (
        transcriber, (tokenizer, model), session, base_frame_emb
    )


audio_transcriber_model, text_summarizer, video_rating_model, base_frame_emb = load_models()
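
# Illustrative call order for these helpers (a sketch only; the Streamlit page that
# collects the audio tensor and video frames is not part of this file, and the
# variable names `audio_tensor` / `video_frames` below are hypothetical):
#
#   transcript = get_text_from_audio(audio_tensor)   # speech  -> raw transcript
#   summary    = summarize_from_text(transcript)     # text    -> short summary
#   verdict    = rate_video_frames(video_frames)     # frames  -> age-rating message
#   st.write(summary)
#   st.write(verdict)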