import os
import cv2
import gradio as gr
import numpy as np
import json
import pickle
from PIL import Image
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import BridgeTowerProcessor
from tqdm import tqdm
from bridgetower_custom import BridgeTowerTextFeatureExtractor, BridgeTowerForITC
import faiss
import webvtt
from pytube import YouTube
from youtube_transcript_api import YouTubeTranscriptApi
from youtube_transcript_api.formatters import WebVTTFormatter
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
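# Model setup (a descriptive note, not from the original source): BridgeTowerForITC and
# BridgeTowerTextFeatureExtractor are custom classes from bridgetower_custom.py; the first
# embeds (image, text) pairs, and the second presumably projects text-only queries into the
# same embedding space so they can be compared against the cached frame embeddings.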
model_name = 'BridgeTower/bridgetower-large-itm-mlm-itc'
model = BridgeTowerForITC.from_pretrained(model_name).to(device)
text_model = BridgeTowerTextFeatureExtractor.from_pretrained(model_name).to(device)
processor = BridgeTowerProcessor.from_pretrained(model_name)
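# Download the highest-resolution progressive (audio+video) MP4 stream with pytube,
# skipping the download if the file is already cached under `path`.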
def download_video(video_url, path='/tmp/'):
    yt = YouTube(video_url)
    yt = yt.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first()
    if not os.path.exists(path):
        os.makedirs(path)
    filepath = os.path.join(path, yt.default_filename)
    if not os.path.exists(filepath):
        print('Downloading video from YouTube...')
        yt.download(path)
    return filepath
# Fetch the YouTube transcript and save it as a WebVTT file
def get_transcript_vtt(video_id, path='/tmp'):
    filepath = os.path.join(path, 'test_vm.vtt')
    if os.path.exists(filepath):
        return filepath
    transcript = YouTubeTranscriptApi.get_transcript(video_id)
    formatter = WebVTTFormatter()
    webvtt_formatted = formatter.format_transcript(transcript)
    with open(filepath, 'w', encoding='utf-8') as webvtt_file:
        webvtt_file.write(webvtt_formatted)
    return filepath
# https://stackoverflow.com/a/57781047
# Resizes an image while maintaining its aspect ratio
def maintain_aspect_ratio_resize(image, width=None, height=None, inter=cv2.INTER_AREA):
    # Grab the image size and initialize dimensions
    dim = None
    (h, w) = image.shape[:2]

    # Return the original image if there is no need to resize
    if width is None and height is None:
        return image

    # We are resizing height if width is None
    if width is None:
        # Calculate the ratio of the height and construct the dimensions
        r = height / float(h)
        dim = (int(w * r), height)
    # We are resizing width if height is None
    else:
        # Calculate the ratio of the width and construct the dimensions
        r = width / float(w)
        dim = (width, int(h * r))

    # Return the resized image
    return cv2.resize(image, dim, interpolation=inter)
def time_to_frame(time, fps):
    '''
    Convert a time in seconds into a frame number.
    '''
    return int(time * fps - 1)
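# Parse a WebVTT timestamp of the form 'HH:MM:SS.mmm' into a number of seconds.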
def str2time(strtime):
    strtime = strtime.strip('"')
    hrs, mins, seconds = [float(c) for c in strtime.split(':')]
    total_seconds = hrs * 60**2 + mins * 60 + seconds
    return total_seconds
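# Merge a list of single-example processor outputs into one batch: text tensors are padded
# to the longest sequence in the batch, image tensors are concatenated along the batch axis.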
def collate_fn(batch_list):
    batch = {}
    batch['input_ids'] = pad_sequence([encoding['input_ids'].squeeze(0) for encoding in batch_list], batch_first=True)
    batch['attention_mask'] = pad_sequence([encoding['attention_mask'].squeeze(0) for encoding in batch_list], batch_first=True)
    batch['pixel_values'] = torch.cat([encoding['pixel_values'] for encoding in batch_list], dim=0)
    batch['pixel_mask'] = torch.cat([encoding['pixel_mask'] for encoding in batch_list], dim=0)
    return batch
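# For every caption in the transcript, grab the frame at the caption's midpoint and run
# BridgeTower on the (frame, caption) pair in batches. The vector kept per pair is
# outputs.logits[i, 2, :], which in the custom BridgeTowerForITC head is assumed to be the
# cross-modal (image + text) embedding. Results are cached to annotations.json and
# embeddings.pkl so repeat queries on the same video can skip this step.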
def extract_images_and_embeds(video_id, video_path, subtitles, output, expanded=False, batch_size=2, progress=gr.Progress()):
    if os.path.exists(os.path.join(output, 'embeddings.pkl')):
        return

    os.makedirs(output, exist_ok=True)
    os.makedirs(os.path.join(output, 'frames'), exist_ok=True)
    os.makedirs(os.path.join(output, 'frames_thumb'), exist_ok=True)

    count = 0
    vidcap = cv2.VideoCapture(video_path)

    # Get the frames per second
    fps = vidcap.get(cv2.CAP_PROP_FPS)
    # Get the total number of frames in the video
    frame_count = vidcap.get(cv2.CAP_PROP_FRAME_COUNT)
    # print(fps, frame_count)

    frame_number = 0
    anno = []
    embeddings = []
    batch_list = []

    vtt = webvtt.read(subtitles)
    for idx, caption in enumerate(tqdm(vtt.captions, desc="Generating embeddings")):
        st_time = str2time(caption.start)
        ed_time = str2time(caption.end)
        mid_time = (ed_time + st_time) / 2
        text = caption.text.replace('\n', ' ')
        if expanded:
            raise NotImplementedError

        frame_no = time_to_frame(mid_time, fps)
        mid_time_ms = mid_time * 1000
        # vidcap.set(1, frame_no) # added this line
        vidcap.set(cv2.CAP_PROP_POS_MSEC, mid_time_ms)
        print('Read a new frame: ', idx, mid_time, frame_no, text)
        success, frame = vidcap.read()
        if success:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = Image.fromarray(frame)
            img_fname = f'{video_id}_{idx:06d}'
            img_fpath = os.path.join(output, 'frames', img_fname + '.jpg')
            # image = maintain_aspect_ratio_resize(image, height=350) # save frame as JPEG file
            # cv2.imwrite(img_fpath, image) # save frame as JPEG file
            count += 1

            anno.append({
                'image_id': idx,
                'img_fname': img_fname,
                'caption': text,
                'time': mid_time_ms,
                'frame_no': frame_no
            })

            encoding = processor(frame, text, return_tensors="pt").to(device)
            encoding['text'] = text
            encoding['image_filepath'] = img_fpath
            encoding['start_time'] = caption.start
            encoding['time'] = mid_time_ms
            batch_list.append(encoding)
        else:
            break

        if len(batch_list) == batch_size:
            batch = collate_fn(batch_list)
            with torch.no_grad():
                outputs = model(**batch, output_hidden_states=True)
            for i in range(batch_size):
                embeddings.append({
                    'embeddings': outputs.logits[i, 2, :].detach().cpu().numpy(),
                    'text': batch_list[i]['text'],
                    'image_filepath': batch_list[i]['image_filepath'],
                    'start_time': batch_list[i]['start_time'],
                    'time': batch_list[i]['time'],
                })
            batch_list = []
    # Embed any remaining captions that did not fill a complete batch
    if batch_list:
        batch = collate_fn(batch_list)
        with torch.no_grad():
            outputs = model(**batch, output_hidden_states=True)
        for i in range(len(batch_list)):
            embeddings.append({
                'embeddings': outputs.logits[i, 2, :].detach().cpu().numpy(),
                'text': batch_list[i]['text'],
                'image_filepath': batch_list[i]['image_filepath'],
                'start_time': batch_list[i]['start_time'],
                'time': batch_list[i]['time'],
            })
        batch_list = []

    with open(os.path.join(output, 'annotations.json'), 'w') as fh:
        json.dump(anno, fh)

    with open(os.path.join(output, 'embeddings.pkl'), 'wb') as fh:
        pickle.dump(embeddings, fh)
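# Embed the text query with the text-only encoder and look up the closest frame/caption
# embeddings in a FAISS inner-product index, which is built on first use and pickled next
# to the embeddings. A rough sketch of the FAISS calls involved, with hypothetical shapes
# (the 512-d size is illustrative, not taken from the model config):
#
#   index = faiss.IndexFlatIP(512)           # exact inner-product search over 512-d vectors
#   index.add(vectors)                       # vectors: float32 array of shape (N, 512)
#   scores, ids = index.search(query, 6)     # query: float32 array of shape (1, 512)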
def run_query(video_path, text_query, path='/tmp'):
    vidcap = cv2.VideoCapture(video_path)

    embeddings_filepath = os.path.join(path, 'embeddings.pkl')
    faiss_filepath = os.path.join(path, 'faiss_index.pkl')
    embeddings = pickle.load(open(embeddings_filepath, 'rb'))

    if os.path.exists(faiss_filepath):
        faiss_index = pickle.load(open(faiss_filepath, 'rb'))
    else:
        embs = [emb['embeddings'] for emb in embeddings]
        vectors = np.stack(embs, axis=0)
        num_vectors, vector_dim = vectors.shape
        faiss_index = faiss.IndexFlatIP(vector_dim)
        faiss_index.add(vectors)
        pickle.dump(faiss_index, open(faiss_filepath, 'wb'))

    print('Processing query')
    encoding = processor.tokenizer(text_query, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = text_model(**encoding)
    emb_query = outputs.cpu().numpy()

    print('Running FAISS search')
    _, I = faiss_index.search(emb_query, 6)

    clip_images = []
    transcripts = []
    for idx in I[0]:
        # frame_no = embeddings[idx]['frame_no']
        # vidcap.set(1, frame_no) # added this line
        frame_timestamp = embeddings[idx]['time']
        vidcap.set(cv2.CAP_PROP_POS_MSEC, frame_timestamp)
        success, frame = vidcap.read()
        if success:
            frame = maintain_aspect_ratio_resize(frame, height=400)
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = Image.fromarray(frame)
            clip_images.append(frame)
            transcripts.append(f"({embeddings[idx]['start_time']}) {embeddings[idx]['text']}")
    return clip_images, transcripts
# https://stackoverflow.com/a/7936523
def get_video_id_from_url(video_url):
    """
    Examples:
    - http://youtu.be/SA2iWivDJiE
    - http://www.youtube.com/watch?v=_oPAwA_Udwc&feature=feedu
    - http://www.youtube.com/embed/SA2iWivDJiE
    - http://www.youtube.com/v/SA2iWivDJiE?version=3&hl=en_US
    """
    import urllib.parse
    url = urllib.parse.urlparse(video_url)
    if url.hostname == 'youtu.be':
        return url.path[1:]
    if url.hostname in ('www.youtube.com', 'youtube.com'):
        if url.path == '/watch':
            p = urllib.parse.parse_qs(url.query)
            return p['v'][0]
        if url.path[:7] == '/embed/':
            return url.path.split('/')[2]
        if url.path[:3] == '/v/':
            return url.path.split('/')[2]
    return None
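# End-to-end handler for a single Gradio request: resolve the video id, download the video
# and its transcript, build (or reuse) the cached frame embeddings, then run the text query.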
def process(video_url, text_query, progress=gr.Progress(track_tqdm=True)):
    tmp_dir = os.environ.get('TMPDIR', '/tmp')
    video_id = get_video_id_from_url(video_url)
    output_dir = os.path.join(tmp_dir, video_id)

    video_file = download_video(video_url, path=output_dir)
    subtitles = get_transcript_vtt(video_id, path=output_dir)
    extract_images_and_embeds(video_id=video_id,
                              video_path=video_file,
                              subtitles=subtitles,
                              output=output_dir,
                              expanded=False,
                              batch_size=8,
                              progress=progress,
                              )
    frame_paths, transcripts = run_query(video_file, text_query, path=output_dir)
    return video_file, [(image, caption) for image, caption in zip(frame_paths, transcripts)]
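# Gradio UI: a YouTube URL and a free-text query as inputs, a video player for the
# downloaded clip, and a gallery of the best-matching frames with their captions.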
description = "This Space lets you run semantic search on a video."
with gr.Blocks() as demo:
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            video_url = gr.Text(label="YouTube URL")
            text_query = gr.Text(label="Text query")
            btn = gr.Button("Run query")
        video_player = gr.Video(label="Video")
    with gr.Row():
        gallery = gr.Gallery(label="Images").style(grid=6)

    gr.Examples(
        examples=[
            ['https://www.youtube.com/watch?v=CvjoXdC-WkM', 'wedding'],
            ['https://www.youtube.com/watch?v=fWs2dWcNGu0', 'cheesecake'],
            ['https://www.youtube.com/watch?v=rmPpNsx4yAk', 'bunny'],
            ['https://www.youtube.com/watch?v=KCFYf4TJdN0', 'sandwich'],
        ],
        inputs=[video_url, text_query],
    )

    btn.click(fn=process,
              inputs=[video_url, text_query],
              outputs=[video_player, gallery],
              )
try:
    demo.queue(concurrency_count=3)
    demo.launch(share=True)
except:
    demo.launch()