File size: 4,707 Bytes
9915c5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33ac1d8
9915c5d
 
33ac1d8
 
9915c5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import gradio as gr
import subprocess

import yaml
from tqdm import tqdm

import imageio
import numpy as np
from skimage.transform import resize
from skimage import img_as_ubyte
import torch
from sync_batchnorm import DataParallelWithCallback

from modules.generator import OcclusionAwareGenerator
from modules.keypoint_detector import KPDetector
from animate import normalize_kp


def load_checkpoints(config_path, checkpoint_path, cpu=False):
    """Build the generator and keypoint detector and load their trained weights.

    Args:
        config_path: Path to the YAML model configuration. It must contain a
            ``model_params`` mapping with ``generator_params``,
            ``kp_detector_params`` and ``common_params`` entries.
        checkpoint_path: Path to a checkpoint readable by ``torch.load`` that
            holds ``"generator"`` and ``"kp_detector"`` state dicts.
        cpu: If true, keep both networks on the CPU and map the checkpoint
            tensors to the CPU. Otherwise both networks are moved to CUDA and
            wrapped in ``DataParallelWithCallback``.

    Returns:
        A ``(generator, kp_detector)`` tuple, both switched to eval mode.
    """
    with open(config_path) as f:
        # yaml.load() without an explicit Loader raises TypeError on
        # PyYAML >= 6 (and warns on 5.x). The config is plain data, so the
        # safe loader is sufficient and avoids constructing arbitrary objects.
        config = yaml.safe_load(f)

    model_params = config["model_params"]
    common_params = model_params["common_params"]

    generator = OcclusionAwareGenerator(**model_params["generator_params"], **common_params)
    kp_detector = KPDetector(**model_params["kp_detector_params"], **common_params)
    if not cpu:
        generator.cuda()
        kp_detector.cuda()

    # map_location=None is torch.load's default, i.e. tensors are restored to
    # the devices they were saved from.
    map_location = torch.device("cpu") if cpu else None
    checkpoint = torch.load(checkpoint_path, map_location=map_location)

    # Weights are loaded before the DataParallel wrapping below, so the
    # checkpoint keys are matched against the bare modules.
    generator.load_state_dict(checkpoint["generator"])
    kp_detector.load_state_dict(checkpoint["kp_detector"])

    if not cpu:
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    return generator, kp_detector


def make_animation(
    source_image, driving_video, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False
):
    """Animate ``source_image`` with the motion of ``driving_video``.

    Args:
        source_image: 3-D array-like image; its last axis is moved to
            position 1 before it is fed to the networks (presumably
            height x width x channels -- confirm against the caller).
        driving_video: Sequence of frames that stacks into a 4-D array whose
            first axis indexes frames; each frame is laid out like
            ``source_image``.
        generator: Callable invoked as
            ``generator(source, kp_source=..., kp_driving=...)`` that returns
            a mapping with a ``"prediction"`` tensor.
        kp_detector: Callable that maps an image batch to keypoints.
        relative: Forwarded to ``normalize_kp`` as both
            ``use_relative_movement`` and ``use_relative_jacobian``.
        adapt_movement_scale: Forwarded to ``normalize_kp``.
        cpu: If false, input tensors are moved to CUDA before use.

    Returns:
        A list with one NumPy array per driving frame: the first batch element
        of the prediction, with axes reordered by ``[0, 2, 3, 1]``.
    """
    with torch.no_grad():
        predictions = []
        source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
        if not cpu:
            source = source.cuda()
        # The full driving clip stays on the host; frames are transferred one
        # at a time below.
        driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)
        kp_source = kp_detector(source)

        # Move the reference frame to the same device as every other input.
        # Previously it was passed from host memory even in CUDA mode, which
        # only works if the detector wrapper relocates its inputs itself.
        initial_frame = driving[:, :, 0]
        if not cpu:
            initial_frame = initial_frame.cuda()
        kp_driving_initial = kp_detector(initial_frame)

        for frame_idx in tqdm(range(driving.shape[2])):
            driving_frame = driving[:, :, frame_idx]
            if not cpu:
                driving_frame = driving_frame.cuda()
            kp_driving = kp_detector(driving_frame)
            kp_norm = normalize_kp(
                kp_source=kp_source,
                kp_driving=kp_driving,
                kp_driving_initial=kp_driving_initial,
                use_relative_movement=relative,
                use_relative_jacobian=relative,
                adapt_movement_scale=adapt_movement_scale,
            )
            out = generator(source, kp_source=kp_source, kp_driving=kp_norm)

            # Back to host memory, reorder axes with [0, 2, 3, 1], drop the
            # batch axis.
            frame = out["prediction"].detach().cpu().numpy()
            predictions.append(np.transpose(frame, [0, 2, 3, 1])[0])
    return predictions


def inference(video, image):
    """Run the animation demo for one request and return the result video path.

    The source image and every driving frame are resized to 256x256 and cut
    to their first three channels, animated with the module-level
    ``generator`` / ``kp_detector`` on the CPU, and written out with ffmpeg
    as a side-by-side clip (driving video on the left, animation on the
    right).

    Args:
        video: Path of the driving video file.
        image: Path of the source image file.

    Returns:
        The string ``"final.mp4"``, the path of the combined video written to
        the current working directory.

    Raises:
        ValueError: If no frame could be decoded from ``video``.
        subprocess.CalledProcessError: If an ffmpeg invocation fails.

    Note:
        Intermediate and final files use fixed names (``result.mp4``,
        ``driving.mp4``, ``out.mp4``, ``final.mp4``), so overlapping calls
        would overwrite each other's files.
    """
    source_image = imageio.imread(image)

    reader = imageio.get_reader(video)
    driving_video = []
    try:
        fps = reader.get_meta_data()["fps"]
        try:
            for im in reader:
                driving_video.append(im)
        except RuntimeError:
            # Best effort: keep the frames decoded before the reader failed.
            pass
    finally:
        # Release the reader even if metadata lookup or decoding raised.
        reader.close()

    if not driving_video:
        raise ValueError(f"no frames could be read from driving video: {video!r}")

    source_image = resize(source_image, (256, 256))[..., :3]
    driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]

    predictions = make_animation(
        source_image,
        driving_video,
        generator,
        kp_detector,
        relative=True,
        adapt_movement_scale=True,
        cpu=True,
    )
    imageio.mimsave("result.mp4", [img_as_ubyte(frame) for frame in predictions], fps=fps)
    imageio.mimsave("driving.mp4", [img_as_ubyte(frame) for frame in driving_video], fps=fps)

    # Argument lists (not a split command string) keep upload paths that
    # contain whitespace intact.
    # Combine the generated frames with the driving clip's audio track. The
    # trailing "?" makes the audio mapping optional so a silent driving video
    # still produces out.mp4.
    subprocess.run(
        [
            "ffmpeg", "-y",
            "-i", "result.mp4",
            "-i", video,
            "-c", "copy",
            "-map", "0:0",
            "-map", "1:a:0?",
            "-shortest",
            "out.mp4",
        ],
        check=True,
    )
    # Stack the driving clip and the animation side by side. This command was
    # previously assembled but never executed, so final.mp4 was never written.
    subprocess.run(
        [
            "ffmpeg", "-y",
            "-i", "driving.mp4",
            "-i", "out.mp4",
            "-filter_complex", "hstack=inputs=2",
            "final.mp4",
        ],
        check=True,
    )
    return "final.mp4"


# Load the pretrained networks once at start-up; `inference` reads these two
# module-level names on every request.
generator, kp_detector = load_checkpoints(
    config_path="config/vox-256.yaml", checkpoint_path="weights/vox-adv-cpk.pth.tar", cpu=True
)

# Text shown on the demo page.
title = "First Order Motion Model"
description = "Gradio demo for First Order Motion Model. Read more at the links below."
article = (
    "<p style='text-align: center'>"
    "<a href='https://papers.nips.cc/paper/2019/file/31c0b36aef265d9221af80872ceb62f9-Paper.pdf' target='_blank'>"
    "First Order Motion Model for Image Animation</a>"
    " | "
    "<a href='https://github.com/AliaksandrSiarohin/first-order-model' target='_blank'>Github Repo</a>"
    "</p>"
)

# One example row, in the same order as the inputs: driving video, source image.
examples = [["bella_porch.mp4", "julien.png"]]

iface = gr.Interface(
    fn=inference,
    inputs=[gr.inputs.Video(type="mp4"), gr.inputs.Image(type="filepath")],
    outputs=gr.outputs.Video(label="Output Video"),
    title=title,
    description=description,
    article=article,
    examples=examples,
    enable_queue=True,
    server_name="0.0.0.0",
)
iface.launch(debug=True)