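# Gradio demo: detect scene cuts with PySceneDetect, score each scene's
# relevance to a free-text query with R2-Tuning (CLIP-based saliency), then
# export the scenes as MP4 clips ranked by score, with thumbnails and a
# relevance-over-time plot.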
import os
from datetime import datetime

import cv2
import gradio as gr
import numpy as np
import pandas as pd

from scenedetect import open_video, SceneManager
from scenedetect.detectors import ContentDetector
from moviepy.editor import VideoFileClip

import clip
import decord
import nncore
import torch
import torchvision.transforms.functional as F
from nncore.engine import load_checkpoint
from nncore.nn import build_model
|
|
def convert_time(seconds):
    """Format a duration in seconds as an MM:SS string."""
    minutes, seconds = divmod(round(max(seconds, 0)), 60)
    return f'{minutes:02d}:{seconds:02d}'
|
|
# R2-Tuning config and pretrained QVHighlights checkpoint.
TUNING_CONFIG = 'configs/qvhighlights/r2_tuning_qvhighlights.py'
TUNING_WEIGHT = 'https://huggingface.co/yeliudev/R2-Tuning/resolve/main/checkpoints/r2_tuning_qvhighlights-ed516355.pth'
|
|
def init_tuning_model(config, checkpoint):
    """Build the R2-Tuning model and load its pretrained weights."""
    cfg = nncore.Config.from_file(config)
    cfg.model.init = True

    # Remote checkpoints are downloaded to a local cache first.
    if checkpoint.startswith('http'):
        checkpoint = nncore.download(checkpoint, out_dir='checkpoints')

    model = build_model(cfg.model, dist=False).eval()
    model = load_checkpoint(model, checkpoint, warning=False)
    return model, cfg
|
|
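# Loading the model once at import time keeps per-request latency low; the
# checkpoint is downloaded from Hugging Face on first run.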
tuning_model, tuning_cfg = init_tuning_model(TUNING_CONFIG, TUNING_WEIGHT) |
|
|
def preprocess_video(video_path, cfg):
    """Sample, center-crop, resize, and normalize frames into model input."""
    decord.bridge.set_bridge('torch')
    vr = decord.VideoReader(video_path)

    # Sample frames at the rate the model expects (cfg.data.val.fps).
    stride = vr.get_avg_fps() / cfg.data.val.fps
    fm_idx = [min(round(i), len(vr) - 1) for i in np.arange(0, len(vr), stride).tolist()]
    video = vr.get_batch(fm_idx).permute(0, 3, 1, 2).float() / 255

    # Center-crop to a square and resize to the CLIP input resolution.
    size = 336 if '336px' in cfg.model.arch else 224
    h, w = video.size(-2), video.size(-1)
    s = min(h, w)
    x, y = round((h - s) / 2), round((w - s) / 2)
    video = video[..., x:x + s, y:y + s]
    video = F.resize(video, size=(size, size))

    # Normalize with CLIP's RGB mean and std.
    video = F.normalize(video, (0.481, 0.459, 0.408), (0.269, 0.261, 0.276))

    # Flatten each frame and add a batch dimension.
    return video.reshape(video.size(0), -1).unsqueeze(0)
|
|
def calculate_saliency(video_path, query, model, cfg):
    """Score each sampled clip's relevance to the text query."""
    if len(query) == 0:
        return None, None, 0

    video = preprocess_video(video_path, cfg)
    query = clip.tokenize(query, truncate=True)

    device = next(model.parameters()).device
    data = dict(video=video.to(device), query=query.to(device), fps=[cfg.data.val.fps])

    with torch.inference_mode():
        pred = model(data)

    # Min-max normalize the saliency scores into (0.05, 0.95); the epsilon
    # guards against division by zero when all scores are identical.
    hd = pred['_out']['saliency'].cpu()
    hd = ((hd - hd.min()) / (hd.max() - hd.min() + 1e-8) * 0.9 + 0.05).numpy()

    # Each score covers one sampled clip: 2 seconds at the default fps of 0.5.
    time_axis = np.arange(0, len(hd) * 2, 2)

    vr = decord.VideoReader(video_path)
    duration = len(vr) / vr.get_avg_fps()
    return hd, time_axis, duration
|
|
def find_scenes(video_path, threshold, query):
    """Detect scene cuts, score each scene against the query, and export clips."""
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    output_dir = f"output_{timestamp}"
    os.makedirs(output_dir, exist_ok=True)

    saliency_scores, time_points, total_duration = calculate_saliency(video_path, query, tuning_model, tuning_cfg)
    if saliency_scores is None:
        raise gr.Error("Please enter a valid text query")

    # Resample the clip-level scores onto a 10 Hz timeline for plotting and
    # per-scene averaging.
    new_time = np.linspace(0, total_duration, num=int(total_duration * 10))
    interp_scores = np.interp(new_time, time_points, saliency_scores)

    # Detect scene cuts with PySceneDetect's content-based detector.
    filename = os.path.splitext(os.path.basename(video_path))[0]
    video = open_video(video_path)
    scene_manager = SceneManager()
    scene_manager.add_detector(ContentDetector(threshold=threshold))
    scene_manager.detect_scenes(video, show_progress=True)
    scene_list = scene_manager.get_scene_list()

    if not scene_list:
        gr.Warning("No scenes detected in this video")
        return None, None, None, None

    # Average the interpolated scores over each scene's time span.
    processed_scenes = []
    for shot in scene_list:
        start_sec = shot[0].get_seconds()
        end_sec = shot[1].get_seconds()

        start_idx = np.searchsorted(new_time, start_sec, side='left')
        end_idx = np.searchsorted(new_time, end_sec, side='right')
        valid_scores = interp_scores[start_idx:end_idx]

        valid_scores = valid_scores[~np.isnan(valid_scores)]
        scene_score = valid_scores.mean() if len(valid_scores) > 0 else 0.0

        scene_info = {
            "start": convert_time(start_sec),
            "end": convert_time(end_sec),
            "score": round(float(scene_score), 3),
            "start_sec": start_sec,
            "end_sec": end_sec
        }
        processed_scenes.append(scene_info)

    # Rank scenes by relevance to the query, highest first.
    processed_scenes.sort(key=lambda x: x['score'], reverse=True)

    timecodes = [{"title": filename + ".mp4", "fps": scene_list[0][0].get_framerate()}]
    shots = []
    stills = []

    # Export each scene as an MP4 clip and grab a thumbnail of its first frame.
    for idx, scene in enumerate(processed_scenes):
        shot_name = f"shot_{idx + 1}_{filename}"
        target_name = os.path.join(output_dir, f"{shot_name}.mp4")

        # `video_clip` avoids shadowing the imported `clip` module.
        with VideoFileClip(video_path) as video_clip:
            subclip = video_clip.subclip(scene['start_sec'], scene['end_sec'])
            subclip.write_videofile(target_name,
                                    codec="libx264",
                                    audio_codec="aac",
                                    threads=4,
                                    preset="fast",
                                    ffmpeg_params=["-crf", "23"])

        # Thumbnail from the first frame of the scene; skip the write if the
        # frame cannot be decoded.
        vid = cv2.VideoCapture(video_path)
        vid.set(cv2.CAP_PROP_POS_MSEC, scene['start_sec'] * 1000)
        ret, frame = vid.read()
        vid.release()

        img_path = os.path.join(output_dir, f"{shot_name}_screenshot.png")
        if ret:
            cv2.imwrite(img_path, frame)

        timecodes.append({
            "tc_in": scene['start'],
            "tc_out": scene['end'],
            "score": scene['score'],
            "shot_name": shot_name
        })
        shots.append(target_name)
        stills.append((img_path, f'{shot_name}\nScore: {scene["score"]:.3f}'))

    # Data for the relevance-over-time line plot.
    plot_data = pd.DataFrame({
        'x': new_time,
        'y': interp_scores
    })

    return timecodes, shots, stills, plot_data
|
|
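# Gradio UI: upload a video, enter a text query, tune the cut-detection
# threshold, and inspect the ranked scenes, clips, thumbnails, and score plot.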
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("""
        # Enhanced Scene Edit Detection
        New features:
        1. Analyze how relevant the video content is to a text query
        2. Plot the relevance scores over time
        3. Output the split clips sorted by relevance score
        """)

        with gr.Row():
            with gr.Column():
                video_input = gr.Video(sources="upload", format="mp4", label="Input video")
                query_input = gr.Textbox(label="Text query", placeholder="Describe the video content (5-15 words works best)")
                threshold = gr.Slider(label="Scene cut detection threshold", minimum=15.0, maximum=40.0, value=27.0)
                with gr.Row():
                    clear_button = gr.Button("Clear")
                    run_button = gr.Button("Run", variant="primary")
                plot_output = gr.LinePlot(x='x', y='y', x_title='Time (s)',
                                          y_title='Relevance score', label='Relevance over time')
            with gr.Column():
                json_output = gr.JSON(label="Scene analysis results (sorted by score)")
                file_output = gr.File(label="Download split clips")
                gallery_output = gr.Gallery(label="Scene thumbnails", object_fit="cover", columns=3)

    run_button.click(
        fn=find_scenes,
        inputs=[video_input, threshold, query_input],
        outputs=[json_output, file_output, gallery_output, plot_output]
    )
    # Reset every component, including the plot, to its default state.
    clear_button.click(
        fn=lambda: [None, 27, None, None, None, None, None],
        inputs=None,
        outputs=[video_input, threshold, query_input, json_output, file_output, gallery_output, plot_output]
    )

    gr.Examples(
        examples=[
            ["anime_kiss.mp4", 27, "A romantic kiss scene between two characters"],
            ["anime_tear.mp4", 30, "An anime character is crying."]
        ],
        inputs=[video_input, threshold, query_input],
        outputs=[json_output, file_output, gallery_output, plot_output],
        fn=find_scenes,
        cache_examples=False
    )
|
|
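# queue() processes long-running jobs sequentially; share=True also exposes a
# temporary public Gradio link.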
demo.queue().launch(debug=True, share=True) |