Create convert.py
convert.py  +190 -0
ADDED
@@ -0,0 +1,190 @@
import torch
import av
import pims

import numpy as np

from typing import Optional, Tuple
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm


class VideoReader(Dataset):
    def __init__(self, path, transform=None):
        self.video = pims.PyAVVideoReader(path)
        self.rate = self.video.frame_rate
        self.transform = transform

    @property
    def frame_rate(self):
        return self.rate

    def __len__(self):
        return len(self.video)

    def __getitem__(self, idx):
        frame = self.video[idx]
        frame = Image.fromarray(np.asarray(frame))
        if self.transform is not None:
            frame = self.transform(frame)
        return frame


class VideoWriter:
    def __init__(self, path, frame_rate, bit_rate=1000000):
        self.container = av.open(path, mode="w")
        self.stream = self.container.add_stream("h264", rate=f"{frame_rate:.4f}")
        self.stream.pix_fmt = "yuv420p"
        self.stream.bit_rate = bit_rate

    def write(self, frames):
        # frames: [T, C, H, W], float in [0, 1]
        self.stream.width = frames.size(3)
        self.stream.height = frames.size(2)
        if frames.size(1) == 1:
            frames = frames.repeat(1, 3, 1, 1)  # convert grayscale to RGB
        frames = frames.mul(255).byte().cpu().permute(0, 2, 3, 1).numpy()
        for t in range(frames.shape[0]):
            frame = frames[t]
            frame = av.VideoFrame.from_ndarray(frame, format="rgb24")
            self.container.mux(self.stream.encode(frame))

    def close(self):
        self.container.mux(self.stream.encode())  # flush any buffered frames
        self.container.close()


def auto_downsample_ratio(h, w):
    """
    Automatically find a downsample ratio so that the largest side of the resolution is at most 512 px.
    """
    return min(512 / max(h, w), 1)


def convert_video(
    model,
    input_source: str,
    input_resize: Optional[Tuple[int, int]] = None,
    downsample_ratio: Optional[float] = None,
    output_composition: Optional[str] = None,
    output_alpha: Optional[str] = None,
    output_foreground: Optional[str] = None,
    output_video_mbps: Optional[float] = None,
    seq_chunk: int = 1,
    num_workers: int = 0,
    progress: bool = True,
    device: Optional[str] = None,
    dtype: Optional[torch.dtype] = None,
):
    """
    Args:
        input_source: Path to a video file. Frames are read in order via pims.
        input_resize: If provided, the input is first resized to (w, h).
        downsample_ratio: The model's downsample_ratio hyperparameter. If not provided, it is set automatically.
        output_composition: Path of the composition output video. The composition is rendered over a solid black background.
        output_alpha: Path of the alpha output video from the model.
        output_foreground: Path of the foreground output video from the model.
        output_video_mbps: Output video bitrate in Mbps. Defaults to 1.
        seq_chunk: Number of frames to process at once. Increase it for better parallelism.
        num_workers: PyTorch's DataLoader workers.
        progress: Show a progress bar.
        device: Only needs to be provided manually if model is a TorchScript frozen model.
        dtype: Only needs to be provided manually if model is a TorchScript frozen model.
    """

    assert downsample_ratio is None or (
        downsample_ratio > 0 and downsample_ratio <= 1
    ), "Downsample ratio must be between 0 (exclusive) and 1 (inclusive)."
    assert any(
        [output_composition, output_alpha, output_foreground]
    ), "Must provide at least one output."
    assert seq_chunk >= 1, "Sequence chunk must be >= 1"
    assert num_workers >= 0, "Number of workers must be >= 0"

    # Initialize transform
    if input_resize is not None:
        transform = transforms.Compose(
            [transforms.Resize(input_resize[::-1]), transforms.ToTensor()]
        )
    else:
        transform = transforms.ToTensor()

    # Initialize reader
    source = VideoReader(input_source, transform)
    reader = DataLoader(
        source, batch_size=seq_chunk, pin_memory=True, num_workers=num_workers
    )

    # Initialize writers
    frame_rate = source.frame_rate if isinstance(source, VideoReader) else 30
    output_video_mbps = 1 if output_video_mbps is None else output_video_mbps
    if output_composition is not None:
        writer_com = VideoWriter(
            path=output_composition,
            frame_rate=frame_rate,
            bit_rate=int(output_video_mbps * 1000000),
        )
    if output_alpha is not None:
        writer_pha = VideoWriter(
            path=output_alpha,
            frame_rate=frame_rate,
            bit_rate=int(output_video_mbps * 1000000),
        )
    if output_foreground is not None:
        writer_fgr = VideoWriter(
            path=output_foreground,
            frame_rate=frame_rate,
            bit_rate=int(output_video_mbps * 1000000),
        )

    # Inference
    model = model.eval()
    if device is None or dtype is None:
        param = next(model.parameters())
        dtype = param.dtype
        device = param.device

    if output_composition is not None:
        bgr = (
            torch.tensor([0, 0, 0], device=device, dtype=dtype)
            .div(255)
            .view(1, 1, 3, 1, 1)
        )  # solid black background

    try:
        with torch.no_grad():
            bar = tqdm(total=len(source), disable=not progress, dynamic_ncols=True)
            rec = [None] * 4  # recurrent states carried across chunks
            for src in reader:
                if downsample_ratio is None:
                    downsample_ratio = auto_downsample_ratio(*src.shape[2:])

                src = src.to(device, dtype, non_blocking=True).unsqueeze(
                    0
                )  # [B, T, C, H, W]
                fgr, pha, *rec = model(src, *rec, downsample_ratio)

                if output_foreground is not None:
                    writer_fgr.write(fgr[0])
                if output_alpha is not None:
                    writer_pha.write(pha[0])
                if output_composition is not None:
                    com = fgr * pha + bgr * (1 - pha)
                    writer_com.write(com[0])

                bar.update(src.size(1))

    finally:
        # Clean up
        if output_composition is not None:
            writer_com.close()
        if output_alpha is not None:
            writer_pha.close()
        if output_foreground is not None:
            writer_fgr.close()