# Video_Summary_Beta / model_inference.py
import streamlit as st
import torch
from utils import (
prompt_audio_summarization,
timer,
cosine_sim
)
from transformers import BartForConditionalGeneration, BartTokenizer
import numpy as np
import whisper
from streamlit import session_state as sst
import onnxruntime


@timer
def get_text_from_audio(audio_tensors) -> str:
    """Transcribe in-memory audio with the preloaded Whisper model and return the transcript text."""
    # Move the audio to the same device as the transcriber before decoding.
    audio_tensors = audio_tensors.to(sst['device'])
    result = audio_transcriber_model.transcribe(audio_tensors)
    return result["text"]


@timer
def summarize_from_text(raw_transcription):
    """Summarize the transcript with BART, prepending the summarization prompt."""
    tokenizer, model = text_summarizer
    inputs = tokenizer(
        prompt_audio_summarization + raw_transcription,
        return_tensors="pt",
        max_length=1024,
        truncation=True,
    ).to(sst['device'])
    summary_ids = model.generate(
        **inputs,
        max_length=150,
        min_length=30,
        length_penalty=2.0,
        num_beams=4,
    )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)


@timer
def rate_video_frames(video_frames):
    """Rate sampled video frames for family-friendliness by comparing their embeddings
    against a reference (PG) frame embedding."""
    # Group frames into chunks of 5, e.g. 100 frames -> (20, 5, 224, 224, 3).
    inp_frames = np.array(video_frames, dtype=np.float32).reshape(len(video_frames) // 5, 5, 224, 224, 3)
    inputs_dict = {"frames": inp_frames}
    video_frame_emb = video_rating_model.run(['emb'], inputs_dict)[0]

    overall_sim, count_upg = cosine_sim(
        emb1=base_frame_emb,
        emb2=torch.tensor(video_frame_emb),
        threshold=0.4,
    )
    perc_of_upg = count_upg / (len(video_frames) // 5)

    if perc_of_upg > 0.4:
        return (f"Out of {len(video_frames)} important moments of this video, {count_upg * 5} moments contain "
                f"content rated PG or below. Hence this video is suitable for kids & family.")
    else:
        return (f"Out of {len(video_frames)} important moments of this video, "
                f"{(len(video_frames) // 5 - count_upg) * 5} moments contain at least PG-13 content. "
                f"Hence parental guidance is strongly suggested for this video.")


@st.cache_resource
def load_models():
    """Load and cache all inference models on the best available device."""
    sst['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Whisper for speech-to-text and BART for summarization, both on the same device.
    transcriber = whisper.load_model("base", device=sst['device'])
    model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn").to(sst['device'])
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")

    # Reference (medoid) frame embedding used by the frame-rating model.
    base_frame_emb = torch.tensor(
        np.load('base_frame_medoid.npz')['arr'],
        dtype=torch.float32,
        device=sst['device'],
    )

    # ONNX siamese model that embeds groups of video frames.
    session = onnxruntime.InferenceSession(
        "video_rating_siamesev2.onnx",
        providers=['CUDAExecutionProvider', 'CPUExecutionProvider'],
    )

    return transcriber, (tokenizer, model), session, base_frame_emb


audio_transcriber_model, text_summarizer, video_rating_model, base_frame_emb = load_models()
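

# --- Usage sketch (hypothetical, not part of the original app flow) ---
# A minimal example of how the helpers above might be exercised end to end.
# The media path ("sample_video.mp4") and the random frames are placeholders;
# in the real app the Streamlit front end supplies the decoded audio tensor and
# the sampled 224x224 RGB frames. Run via `streamlit run` so st.session_state exists.
if __name__ == "__main__":
    # Audio path: load with Whisper's ffmpeg-backed loader, transcribe, then summarize.
    audio = torch.from_numpy(whisper.load_audio("sample_video.mp4"))
    transcript = get_text_from_audio(audio)
    st.write(summarize_from_text(transcript))

    # Frame path: 20 dummy frames (4 groups of 5) stand in for frames sampled from the video.
    dummy_frames = list(np.random.rand(20, 224, 224, 3).astype(np.float32))
    st.write(rate_video_frames(dummy_frames))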