# Dependencies
import gradio as gr
import cv2
import numpy as np
import os
from datetime import datetime  # time module for timestamped output directories
from scenedetect import open_video, SceneManager
from scenedetect.detectors import ContentDetector
from moviepy.editor import VideoFileClip
import clip
import decord
import nncore
import torch
import torchvision.transforms.functional as F
from nncore.engine import load_checkpoint
from nncore.nn import build_model
import pandas as pd
def convert_time(seconds):
minutes, seconds = divmod(round(max(seconds, 0)), 60)
return f'{minutes:02d}:{seconds:02d}'
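# e.g. convert_time(83.4) -> '01:23'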
# R2-Tuning model configuration
TUNING_CONFIG = 'configs/qvhighlights/r2_tuning_qvhighlights.py'
TUNING_WEIGHT = 'https://huggingface.co/yeliudev/R2-Tuning/resolve/main/checkpoints/r2_tuning_qvhighlights-ed516355.pth'
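# Assumes the R2-Tuning repo layout: the config file is resolved relative to the
# working directory, and the checkpoint is downloaded into ./checkpoints on first run.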
# Initialize the R2-Tuning model
def init_tuning_model(config, checkpoint):
cfg = nncore.Config.from_file(config)
cfg.model.init = True
if checkpoint.startswith('http'):
checkpoint = nncore.download(checkpoint, out_dir='checkpoints')
model = build_model(cfg.model, dist=False).eval()
model = load_checkpoint(model, checkpoint, warning=False)
return model, cfg
tuning_model, tuning_cfg = init_tuning_model(TUNING_CONFIG, TUNING_WEIGHT)
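# The model stays in eval mode; calculate_saliency picks up whatever device its
# parameters live on via next(model.parameters()).device (CPU unless moved to GPU).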
# Video preprocessing function (from the first app)
def preprocess_video(video_path, cfg):
decord.bridge.set_bridge('torch')
vr = decord.VideoReader(video_path)
stride = vr.get_avg_fps() / cfg.data.val.fps
fm_idx = [min(round(i), len(vr) - 1) for i in np.arange(0, len(vr), stride).tolist()]
video = vr.get_batch(fm_idx).permute(0, 3, 1, 2).float() / 255
size = 336 if '336px' in cfg.model.arch else 224
h, w = video.size(-2), video.size(-1)
s = min(h, w)
x, y = round((h - s) / 2), round((w - s) / 2)
video = video[..., x:x + s, y:y + s]
video = F.resize(video, size=(size, size))
video = F.normalize(video, (0.481, 0.459, 0.408), (0.269, 0.261, 0.276))
return video.reshape(video.size(0), -1).unsqueeze(0)
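# Returns a tensor of shape (1, num_frames, 3 * size * size): frames sampled at
# cfg.data.val.fps, centre-cropped to a square, resized, normalised with CLIP
# statistics, then flattened per frame before being fed to the model.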
# calculate_saliency additionally returns the total video duration
def calculate_saliency(video_path, query, model, cfg):
if len(query) == 0:
return None, None, 0
video = preprocess_video(video_path, cfg)
query = clip.tokenize(query, truncate=True)
device = next(model.parameters()).device
data = dict(video=video.to(device), query=query.to(device), fps=[cfg.data.val.fps])
with torch.inference_mode():
pred = model(data)
hd = pred['_out']['saliency'].cpu()
hd = ((hd - hd.min()) / (hd.max() - hd.min()) * 0.9 + 0.05).numpy()
time_axis = np.arange(0, len(hd) * 2, 2)
    # Get the total video duration
vr = decord.VideoReader(video_path)
duration = len(vr) / vr.get_avg_fps()
return hd, time_axis, duration
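# The 2-second spacing of time_axis assumes cfg.data.val.fps == 0.5 in the
# QVHighlights config (one saliency score per 2 s of video); adjust if a
# different sampling rate is used.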
# Modified scene detection function
def find_scenes(video_path, threshold, query):
    # Initialize the output directory
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
output_dir = f"output_{timestamp}"
os.makedirs(output_dir, exist_ok=True)
    # Compute the saliency scores
saliency_scores, time_points, total_duration = calculate_saliency(video_path, query, tuning_model, tuning_cfg)
if saliency_scores is None:
        raise gr.Error("Please enter a valid text query")
    # Linear interpolation onto a denser time axis
    new_time = np.linspace(0, total_duration, num=int(total_duration*10))  # roughly one sample every 0.1 s
interp_scores = np.interp(new_time, time_points, saliency_scores)
    # Scene detection (original logic)
filename = os.path.splitext(os.path.basename(video_path))[0]
video = open_video(video_path)
scene_manager = SceneManager()
scene_manager.add_detector(ContentDetector(threshold=threshold))
scene_manager.detect_scenes(video, show_progress=True)
scene_list = scene_manager.get_scene_list()
if not scene_list:
gr.Warning("No scenes detected in this video")
return None, None, None, None
    # Process each detected scene
processed_scenes = []
for i, shot in enumerate(scene_list):
        # Scene start/end times in seconds
start_sec = shot[0].get_seconds()
end_sec = shot[1].get_seconds()
        # Score the scene using the interpolated saliency values
start_idx = np.searchsorted(new_time, start_sec, side='left')
end_idx = np.searchsorted(new_time, end_sec, side='right')
valid_scores = interp_scores[start_idx:end_idx]
        # Handle possible NaNs (should not occur after interpolation)
valid_scores = valid_scores[~np.isnan(valid_scores)]
scene_score = valid_scores.mean() if len(valid_scores) > 0 else 0.0
        # Save the clip info
scene_info = {
"start": convert_time(start_sec),
"end": convert_time(end_sec),
"score": round(float(scene_score), 3),
"start_sec": start_sec,
"end_sec": end_sec
}
processed_scenes.append(scene_info)
    # Sort by score
processed_scenes.sort(key=lambda x: x['score'], reverse=True)
    # Build the output content
timecodes = [{"title": filename + ".mp4", "fps": scene_list[0][0].get_framerate()}]
shots = []
stills = []
for idx, scene in enumerate(processed_scenes):
        # Build the clip name
shot_name = f"shot_{idx+1}_{filename}"
target_name = os.path.join(output_dir, f"{shot_name}.mp4")
        # Cut out the video segment
with VideoFileClip(video_path) as clip:
subclip = clip.subclip(scene['start_sec'], scene['end_sec'])
subclip.write_videofile(target_name,
codec="libx264",
audio_codec="aac",
threads=4,
preset="fast",
ffmpeg_params=["-crf", "23"])
        # Grab a thumbnail at the scene start
vid = cv2.VideoCapture(video_path)
vid.set(cv2.CAP_PROP_POS_MSEC, scene['start_sec']*1000)
ret, frame = vid.read()
img_path = os.path.join(output_dir, f"{shot_name}_screenshot.png")
cv2.imwrite(img_path, frame)
vid.release()
        # Record the results (labelled with the clip name)
timecodes.append({
"tc_in": scene['start'],
"tc_out": scene['end'],
"score": scene['score'],
"shot_name": shot_name
})
shots.append(target_name)
stills.append((img_path, f'{shot_name}\nScore: {scene["score"]:.3f}'))
    # Line-plot data (uses the interpolated scores)
plot_data = pd.DataFrame({
'x': new_time,
'y': interp_scores
})
return timecodes, shots, stills, plot_data
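# find_scenes returns (timecodes, shots, stills, plot_data), which feed the JSON,
# File, Gallery and LinePlot components wired up below.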
# Modified Gradio interface
with gr.Blocks() as demo:
with gr.Column():
gr.Markdown("""
# Enhanced Scene-Cut Detection
New features:
1. Enter a text query to analyse how relevant the video content is
2. Display a line plot of relevance over time
3. Output the split clips sorted by relevance score
""")
with gr.Row():
with gr.Column():
                video_input = gr.Video(sources="upload", format="mp4", label="Input Video")
                query_input = gr.Textbox(label="Text Query", placeholder="Enter text describing the video content (5-15 words works best)")
                threshold = gr.Slider(label="Scene-cut detection threshold", minimum=15.0, maximum=40.0, value=27.0)
                with gr.Row():
                    clear_button = gr.Button("Clear")
                    run_button = gr.Button("Run", variant="primary")
                plot_output = gr.LinePlot(x='x', y='y', x_title='Time (s)',
                                          y_title='Relevance score', label='Relevance over time')
with gr.Column():
                json_output = gr.JSON(label="Scene analysis results (sorted by score)")
                file_output = gr.File(label="Download split clips")
                gallery_output = gr.Gallery(label="Scene thumbnails", object_fit="cover", columns=3)
run_button.click(
fn=find_scenes,
inputs=[video_input, threshold, query_input],
outputs=[json_output, file_output, gallery_output, plot_output]
)
clear_button.click(
        fn=lambda: [None, 27, None, None, None, None, None],
        inputs=None,
        outputs=[video_input, threshold, query_input, json_output, file_output, gallery_output, plot_output]
)
gr.Examples(
examples=[
["anime_kiss.mp4", 27, "A romantic kiss scene between two characters"],
["anime_tear.mp4", 30, "An anime character is crying."]
],
inputs=[video_input, threshold, query_input],
outputs=[json_output, file_output, gallery_output, plot_output],
fn=find_scenes,
cache_examples=False
)
demo.queue().launch(debug=True, share=True)
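# share=True exposes the app through a temporary public Gradio link in addition
# to the local server started by launch().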