# New dependencies
import gradio as gr
import cv2
import numpy as np
import os
from datetime import datetime  # new: used for timestamped output directories
from scenedetect import open_video, SceneManager
from scenedetect.detectors import ContentDetector
from moviepy.editor import VideoFileClip
import random
from functools import partial
import clip
import decord
import nncore
import torch
import torchvision.transforms.functional as F
from decord import VideoReader
from nncore.engine import load_checkpoint
from nncore.nn import build_model
import pandas as pd


def convert_time(seconds):
    minutes, seconds = divmod(round(max(seconds, 0)), 60)
    return f'{minutes:02d}:{seconds:02d}'


# R2-Tuning model configuration
TUNING_CONFIG = 'configs/qvhighlights/r2_tuning_qvhighlights.py'
TUNING_WEIGHT = 'https://huggingface.co/yeliudev/R2-Tuning/resolve/main/checkpoints/r2_tuning_qvhighlights-ed516355.pth'


# Initialize the R2-Tuning model
def init_tuning_model(config, checkpoint):
    cfg = nncore.Config.from_file(config)
    cfg.model.init = True
    if checkpoint.startswith('http'):
        checkpoint = nncore.download(checkpoint, out_dir='checkpoints')
    model = build_model(cfg.model, dist=False).eval()
    model = load_checkpoint(model, checkpoint, warning=False)
    return model, cfg


tuning_model, tuning_cfg = init_tuning_model(TUNING_CONFIG, TUNING_WEIGHT)


# Video preprocessing (from the first app)
def preprocess_video(video_path, cfg):
    decord.bridge.set_bridge('torch')
    vr = decord.VideoReader(video_path)

    # Sample frames at the FPS expected by the model
    stride = vr.get_avg_fps() / cfg.data.val.fps
    fm_idx = [min(round(i), len(vr) - 1) for i in np.arange(0, len(vr), stride).tolist()]
    video = vr.get_batch(fm_idx).permute(0, 3, 1, 2).float() / 255

    # Center-crop to a square, resize, and normalize with CLIP statistics
    size = 336 if '336px' in cfg.model.arch else 224
    h, w = video.size(-2), video.size(-1)
    s = min(h, w)
    x, y = round((h - s) / 2), round((w - s) / 2)
    video = video[..., x:x + s, y:y + s]
    video = F.resize(video, size=(size, size))
    video = F.normalize(video, (0.481, 0.459, 0.408), (0.269, 0.261, 0.276))

    return video.reshape(video.size(0), -1).unsqueeze(0)


# calculate_saliency additionally returns the total video duration
def calculate_saliency(video_path, query, model, cfg):
    if len(query) == 0:
        return None, None, 0

    video = preprocess_video(video_path, cfg)
    query = clip.tokenize(query, truncate=True)

    device = next(model.parameters()).device
    data = dict(video=video.to(device), query=query.to(device), fps=[cfg.data.val.fps])

    with torch.inference_mode():
        pred = model(data)

    # Rescale saliency scores into the (0.05, 0.95) range for plotting
    hd = pred['_out']['saliency'].cpu()
    hd = ((hd - hd.min()) / (hd.max() - hd.min()) * 0.9 + 0.05).numpy()
    time_axis = np.arange(0, len(hd) * 2, 2)  # one score every 2 seconds

    # Total video duration in seconds
    vr = decord.VideoReader(video_path)
    duration = len(vr) / vr.get_avg_fps()

    return hd, time_axis, duration


# Modified scene detection function
def find_scenes(video_path, threshold, query):
    # Create a timestamped output directory
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    output_dir = f"output_{timestamp}"
    os.makedirs(output_dir, exist_ok=True)

    # Compute saliency scores for the query
    saliency_scores, time_points, total_duration = calculate_saliency(
        video_path, query, tuning_model, tuning_cfg)
    if saliency_scores is None:
        raise gr.Error("Please enter a valid text query")

    # Linear interpolation: roughly ten points per second
    new_time = np.linspace(0, total_duration, num=int(total_duration * 10))
    interp_scores = np.interp(new_time, time_points, saliency_scores)

    # Scene detection (original logic)
    filename = os.path.splitext(os.path.basename(video_path))[0]
    video = open_video(video_path)
    scene_manager = SceneManager()
    scene_manager.add_detector(ContentDetector(threshold=threshold))
    scene_manager.detect_scenes(video, show_progress=True)
    scene_list = scene_manager.get_scene_list()

    if not scene_list:
        gr.Warning("No scenes detected in this video")
        return None, None, None, None

    # Process each detected scene
    processed_scenes = []
    for i, shot in enumerate(scene_list):
        # Scene boundaries in seconds
        start_sec = shot[0].get_seconds()
        end_sec = shot[1].get_seconds()

        # Average the interpolated scores falling inside this scene
        start_idx = np.searchsorted(new_time, start_sec, side='left')
        end_idx = np.searchsorted(new_time, end_sec, side='right')
        valid_scores = interp_scores[start_idx:end_idx]

        # Guard against NaNs (should not occur after interpolation)
        valid_scores = valid_scores[~np.isnan(valid_scores)]
        scene_score = valid_scores.mean() if len(valid_scores) > 0 else 0.0

        # Record the scene information
        scene_info = {
            "start": convert_time(start_sec),
            "end": convert_time(end_sec),
            "score": round(float(scene_score), 3),
            "start_sec": start_sec,
            "end_sec": end_sec
        }
        processed_scenes.append(scene_info)

    # Sort scenes by relevance score (descending)
    processed_scenes.sort(key=lambda x: x['score'], reverse=True)

    # Build the outputs
    timecodes = [{"title": filename + ".mp4", "fps": scene_list[0][0].get_framerate()}]
    shots = []
    stills = []

    for idx, scene in enumerate(processed_scenes):
        # Clip name
        shot_name = f"shot_{idx + 1}_{filename}"
        target_name = os.path.join(output_dir, f"{shot_name}.mp4")

        # Cut the clip (named `video_clip` to avoid shadowing the imported `clip` module)
        with VideoFileClip(video_path) as video_clip:
            subclip = video_clip.subclip(scene['start_sec'], scene['end_sec'])
            subclip.write_videofile(target_name,
                                    codec="libx264",
                                    audio_codec="aac",
                                    threads=4,
                                    preset="fast",
                                    ffmpeg_params=["-crf", "23"])

        # Grab a thumbnail at the start of the scene
        vid = cv2.VideoCapture(video_path)
        vid.set(cv2.CAP_PROP_POS_MSEC, scene['start_sec'] * 1000)
        ret, frame = vid.read()
        img_path = os.path.join(output_dir, f"{shot_name}_screenshot.png")
        if ret:  # skip writing if the frame could not be read
            cv2.imwrite(img_path, frame)
        vid.release()

        # Store the results (labels include the clip name)
        timecodes.append({
            "tc_in": scene['start'],
            "tc_out": scene['end'],
            "score": scene['score'],
            "shot_name": shot_name
        })
        shots.append(target_name)
        stills.append((img_path, f'{shot_name}\nScore: {scene["score"]:.3f}'))

    # Line-plot data (uses the interpolated scores)
    plot_data = pd.DataFrame({
        'x': new_time,
        'y': interp_scores
    })

    return timecodes, shots, stills, plot_data


# Modified interface
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("""
            # Enhanced Scene Edit Detection
            New features:
            1. Analyze how relevant the video content is to a text query
            2. Display a temporal relevance line chart
            3. Output clips sorted by relevance score
            """)
        with gr.Row():
            with gr.Column():
                video_input = gr.Video(sources="upload", format="mp4", label="Input video")
                query_input = gr.Textbox(label="Text query",
                                         placeholder="Describe the video content (5-15 words works best)")
                threshold = gr.Slider(label="Scene-cut detection threshold",
                                      minimum=15.0, maximum=40.0, value=27.0)
                with gr.Row():
                    clear_button = gr.Button("Clear")
                    run_button = gr.Button("Run", variant="primary")
                plot_output = gr.LinePlot(x='x', y='y',
                                          x_title='Time (s)',
                                          y_title='Relevance score',
                                          label='Temporal relevance')
            with gr.Column():
                json_output = gr.JSON(label="Scene analysis (sorted by score)")
                file_output = gr.File(label="Download split clips")
                gallery_output = gr.Gallery(label="Scene thumbnails", object_fit="cover", columns=3)

    run_button.click(
        fn=find_scenes,
        inputs=[video_input, threshold, query_input],
        outputs=[json_output, file_output, gallery_output, plot_output]
    )

    clear_button.click(
        fn=lambda: [None, 27, None, None, None, None],
        inputs=None,
        outputs=[video_input, threshold, query_input, json_output, file_output, gallery_output]
    )

    gr.Examples(
        examples=[
            ["anime_kiss.mp4", 27, "A romantic kiss scene between two characters"],
            ["anime_tear.mp4", 30, "An anime character is crying."]
        ],
        inputs=[video_input, threshold, query_input],
        outputs=[json_output, file_output, gallery_output, plot_output],
        fn=find_scenes,
        cache_examples=False
    )

demo.queue().launch(debug=True, share=True)