hysts's picture
hysts HF staff
Use gr.on
69924a7
raw
history blame
2.46 kB
#!/usr/bin/env python
import os
import pathlib
import tempfile
import cv2
import gradio as gr
import torch
from huggingface_hub import snapshot_download
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
# Markdown shown at the top of the demo page.
DESCRIPTION = "# ModelScope-Vid2Vid-XL"

# Decide once whether CUDA is usable, so the banner text and the pipeline
# setup can never disagree with each other.
if torch.cuda.is_available():
    # Download the model weights into a local cache directory (overridable via
    # the MODEL_CACHE_DIR environment variable) and build the pipeline on GPU 0.
    model_cache_dir = os.getenv("MODEL_CACHE_DIR", "./models")
    model_dir = pathlib.Path(model_cache_dir) / "MS-Vid2Vid-XL"
    snapshot_download(repo_id="damo-vilab/MS-Vid2Vid-XL", repo_type="model", local_dir=model_dir)
    pipe = pipeline(task="video-to-video", model=model_dir.as_posix(), model_revision="v1.1.0", device="cuda:0")
else:
    # The emoji used to be stored as mis-decoded bytes; it is U+1F976.
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
    # No pipeline is built without CUDA; request handlers see ``pipe is None``.
    pipe = None
def check_input_video(video_path: str) -> None:
    """Validate that the input video is 32 frames of size 448x256.

    Args:
        video_path: Filesystem path of the video to inspect. A missing value
            (``None`` or an empty string) is rejected with a user-facing
            error instead of being passed to OpenCV.

    Raises:
        gr.Error: If no video was provided, or if the frame count, width or
            height reported by OpenCV differ from 32, 448 and 256.
    """
    # Guard the "nothing uploaded" case so the user gets a readable message
    # rather than an exception raised from inside cv2.VideoCapture.
    if not video_path:
        raise gr.Error("Please provide an input video.")

    cap = cv2.VideoCapture(video_path)
    try:
        n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    finally:
        # Release the capture even if reading a property raises.
        cap.release()

    if (n_frames, width, height) != (32, 448, 256):
        raise gr.Error(
            f"Input video must be 32 frames of size 448x256. Your video is {n_frames} frames of size {width}x{height}."
        )
def video_to_video(video_path: str, text: str) -> str:
    """Run the video-to-video pipeline on an input video and text prompt.

    Args:
        video_path: Path of the input video; it is validated with
            ``check_input_video`` before the pipeline runs.
        text: Text description passed to the pipeline together with the video.

    Returns:
        Path of the ``.mp4`` file the pipeline was asked to write its output to.

    Raises:
        gr.Error: If the input video fails validation, or if no pipeline was
            loaded (``pipe`` is ``None`` when CUDA is unavailable).
    """
    check_input_video(video_path)
    if pipe is None:
        # Without this guard the call below fails with an opaque TypeError.
        raise gr.Error("This demo does not work on CPU.")

    # Only the reserved path is needed. Closing the handle immediately avoids
    # leaking a file descriptor per request; delete=False keeps the file on
    # disk so it can be written to and returned.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as output_file:
        output_path = output_file.name

    p_input = {"video_path": video_path, "text": text}
    # The subscript result is unused; it is kept so that a result lacking the
    # OUTPUT_VIDEO key still surfaces as an error, as it did before.
    pipe(p_input, output_video=output_path)[OutputKeys.OUTPUT_VIDEO]
    return output_path
# Page layout and event wiring for the demo.
# NOTE(review): the nesting below is reconstructed; in particular it assumes
# output_video sits inside the gr.Group with the other widgets -- confirm.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)

    # The duplicate button is only shown when SHOW_DUPLICATE_BUTTON is "1".
    show_duplicate_button = os.getenv("SHOW_DUPLICATE_BUTTON") == "1"
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=show_duplicate_button,
    )

    with gr.Group():
        # NOTE(review): `type` is not a documented gr.Video parameter;
        # presumably it is absorbed as an extra keyword -- confirm against the
        # installed Gradio version.
        input_video = gr.Video(label="Input video", type="filepath")
        text_description = gr.Textbox(label="Text description")
        run_button = gr.Button()
        output_video = gr.Video(label="Output video")

    # Submitting the textbox and clicking the button start the same flow.
    run_triggers = [text_description.submit, run_button.click]

    # Step 1: validate the upload outside the queue (queue=False) and without
    # a named API endpoint (api_name=False).
    validation_event = gr.on(
        triggers=run_triggers,
        fn=check_input_video,
        inputs=input_video,
        queue=False,
        api_name=False,
    )

    # Step 2: chained with .success, so the pipeline only runs after the
    # validation step; it is registered under the "run" API name.
    validation_event.success(
        fn=video_to_video,
        inputs=[input_video, text_description],
        outputs=output_video,
        api_name="run",
    )
if __name__ == "__main__":
    # Enable the request queue with max_size=10, then start the app.
    # queue() is applied to `demo` itself, so launching `demo` afterwards is
    # the same as the chained queue(...).launch() form.
    demo.queue(max_size=10)
    demo.launch()