File size: 2,425 Bytes
a7af99c
 
 
 
 
 
c461c6c
a7af99c
 
 
 
 
 
 
c461c6c
 
a7af99c
 
 
 
 
 
 
 
c461c6c
 
 
 
 
 
 
 
 
 
 
 
a7af99c
c461c6c
a7af99c
 
 
 
 
 
 
 
c461c6c
 
 
 
 
 
4755710
c461c6c
 
 
69924a7
 
 
c461c6c
 
 
69924a7
c461c6c
a7af99c
 
 
 
 
c461c6c
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#!/usr/bin/env python

import os
import pathlib
import tempfile

import cv2
import gradio as gr
import torch
from huggingface_hub import snapshot_download
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline

DESCRIPTION = "# ModelScope-Vid2Vid-XL"

# Probe for a GPU once; both the banner text and the model loading below depend on it.
CUDA_AVAILABLE = torch.cuda.is_available()

if not CUDA_AVAILABLE:
    # The pipeline is only built on a GPU host, so warn visitors up front.
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

# Video-to-video pipeline. Stays None on CPU-only hosts, where no model is loaded,
# so the name always exists for code that references it.
pipe = None
if CUDA_AVAILABLE:
    # Fetch the checkpoint into $MODEL_CACHE_DIR/MS-Vid2Vid-XL (default ./models/...)
    # and build the ModelScope pipeline on the first CUDA device.
    model_cache_dir = os.getenv("MODEL_CACHE_DIR", "./models")
    model_dir = pathlib.Path(model_cache_dir) / "MS-Vid2Vid-XL"
    snapshot_download(repo_id="damo-vilab/MS-Vid2Vid-XL", repo_type="model", local_dir=model_dir)
    pipe = pipeline(task="video-to-video", model=model_dir.as_posix(), model_revision="v1.1.0", device="cuda:0")


def check_input_video(video_path: str) -> None:
    """Reject input videos that are not exactly 32 frames of 448x256.

    Args:
        video_path: Path of the uploaded video file. May be ``None`` or empty
            when the user has not uploaded anything yet.

    Raises:
        gr.Error: If no video was provided, the file cannot be opened by
            OpenCV, or its frame count / resolution differ from 32 frames
            of 448x256.
    """
    if not video_path:
        raise gr.Error("Please upload an input video.")

    cap = cv2.VideoCapture(video_path)
    try:
        # Without this check an unreadable file reports every property as 0,
        # producing a misleading "0 frames of size 0x0" message below.
        if not cap.isOpened():
            raise gr.Error("Could not read the input video.")
        n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    finally:
        # Release the capture handle on every path, including early errors.
        cap.release()

    if n_frames != 32 or width != 448 or height != 256:
        raise gr.Error(
            f"Input video must be 32 frames of size 448x256. Your video is {n_frames} frames of size {width}x{height}."
        )


def video_to_video(video_path: str, text: str) -> str:
    """Run the MS-Vid2Vid-XL pipeline on an input video and a text description.

    Args:
        video_path: Path of the uploaded input video.
        text: Text description passed to the pipeline alongside the video.

    Returns:
        Path of the generated ``.mp4`` file. It is created with
        ``delete=False`` so it outlives this call and can be served back.

    Raises:
        gr.Error: If no CUDA device is available (the pipeline is only
            loaded on GPU hosts) or the input video fails validation.
    """
    if not torch.cuda.is_available():
        # The module-level `pipe` is only built when CUDA is available.
        raise gr.Error("This demo does not work on CPU.")

    # The UI validates in a separate step before calling this, but clients of
    # the "run" API endpoint may reach this function directly, so re-check.
    check_input_video(video_path)
    p_input = {"video_path": video_path, "text": text}

    # Only reserve a unique output name here. Closing the handle right away
    # avoids leaking a file descriptor, and lets the pipeline reopen the path
    # for writing (an open NamedTemporaryFile cannot be reopened on Windows).
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as output_file:
        output_path = output_file.name

    try:
        pipe(p_input, output_video=output_path)
    except Exception:
        # Don't leave an empty placeholder file behind when generation fails.
        pathlib.Path(output_path).unlink(missing_ok=True)
        raise
    return output_path


# Build the Gradio UI. Components render in the order they are created, so the
# statement order below defines the page layout.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    # Shown only when the SHOW_DUPLICATE_BUTTON environment variable is exactly "1".
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    # Inputs, the run button, and the output are grouped into one visual block.
    with gr.Group():
        input_video = gr.Video(label="Input video")
        text_description = gr.Textbox(label="Text description")
        run_button = gr.Button()
        output_video = gr.Video(label="Output video")

    # Submitting the textbox or clicking the button first runs the cheap input
    # check outside the queue (queue=False) and without its own API endpoint
    # (api_name=False). Only if that check succeeds does the generation step
    # run; it is exposed as the "run" API endpoint.
    gr.on(
        triggers=[text_description.submit, run_button.click],
        fn=check_input_video,
        inputs=input_video,
        queue=False,
        api_name=False,
    ).success(
        fn=video_to_video,
        inputs=[input_video, text_description],
        outputs=output_video,
        api_name="run",
    )

if __name__ == "__main__":
    # Turn on request queuing (holding at most 10 pending events), then serve the app.
    queued_app = demo.queue(max_size=10)
    queued_app.launch()