sanghan committed on
Commit 261ef4c · 1 Parent(s): 1c89dfa

Create convert.py

Files changed (1)
  1. convert.py +190 -0
convert.py ADDED
@@ -0,0 +1,190 @@
+import torch
+import av
+import pims
+
+import numpy as np
+
+from typing import Optional, Tuple
+from torchvision import transforms
+from torch.utils.data import Dataset
+from PIL import Image
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+
+class VideoReader(Dataset):
+    """Frame-by-frame reader for a video file, decoded with PyAV via pims."""
+
+    def __init__(self, path, transform=None):
+        self.video = pims.PyAVVideoReader(path)
+        self.rate = self.video.frame_rate
+        self.transform = transform
+
+    @property
+    def frame_rate(self):
+        return self.rate
+
+    def __len__(self):
+        return len(self.video)
+
+    def __getitem__(self, idx):
+        frame = self.video[idx]
+        frame = Image.fromarray(np.asarray(frame))
+        if self.transform is not None:
+            frame = self.transform(frame)
+        return frame
+
+
+class VideoWriter:
+    """Encodes [T, C, H, W] float tensors in [0, 1] to an H.264 video with PyAV."""
+
+    def __init__(self, path, frame_rate, bit_rate=1000000):
+        self.container = av.open(path, mode="w")
+        self.stream = self.container.add_stream("h264", rate=f"{frame_rate:.4f}")
+        self.stream.pix_fmt = "yuv420p"
+        self.stream.bit_rate = bit_rate
+
+    def write(self, frames):
+        # frames: [T, C, H, W]
+        self.stream.width = frames.size(3)
+        self.stream.height = frames.size(2)
+        if frames.size(1) == 1:
+            frames = frames.repeat(1, 3, 1, 1)  # convert grayscale to RGB
+        frames = frames.mul(255).byte().cpu().permute(0, 2, 3, 1).numpy()
+        for t in range(frames.shape[0]):
+            frame = frames[t]
+            frame = av.VideoFrame.from_ndarray(frame, format="rgb24")
+            self.container.mux(self.stream.encode(frame))
+
+    def close(self):
+        self.container.mux(self.stream.encode())  # flush the encoder
+        self.container.close()
+
+
+def auto_downsample_ratio(h, w):
+    """
+    Automatically find a downsample ratio so that the largest side of the resolution is at most 512 px.
+    """
+    return min(512 / max(h, w), 1)
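+
+# For illustration: auto_downsample_ratio(1080, 1920) returns
+# min(512 / 1920, 1) ≈ 0.267, which scales a 1920x1080 frame's longer side to ~512 px.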
+
+
+def convert_video(
+    model,
+    input_source: str,
+    input_resize: Optional[Tuple[int, int]] = None,
+    downsample_ratio: Optional[float] = None,
+    output_composition: Optional[str] = None,
+    output_alpha: Optional[str] = None,
+    output_foreground: Optional[str] = None,
+    output_video_mbps: Optional[float] = None,
+    seq_chunk: int = 1,
+    num_workers: int = 0,
+    progress: bool = True,
+    device: Optional[str] = None,
+    dtype: Optional[torch.dtype] = None,
+):
+    """
+    Args:
+        input_source: Path to a video file (read with PyAV).
+        input_resize: If provided, the input is first resized to (w, h).
+        downsample_ratio: The model's downsample_ratio hyperparameter. If not provided, it is set automatically from the input resolution.
+        output_composition: Output video path for the composition, i.e. the foreground composited over a solid black background.
+        output_alpha: Output video path for the alpha matte predicted by the model.
+        output_foreground: Output video path for the foreground predicted by the model.
+        output_video_mbps: Output video bitrate in Mbps. Defaults to 1.
+        seq_chunk: Number of frames to process at once. Increase it for better parallelism.
+        num_workers: Number of PyTorch DataLoader workers.
+        progress: Show a progress bar.
+        device: Only needs to be provided manually if the model is a frozen TorchScript model.
+        dtype: Only needs to be provided manually if the model is a frozen TorchScript model.
+    """
+
+    assert downsample_ratio is None or (
+        downsample_ratio > 0 and downsample_ratio <= 1
+    ), "Downsample ratio must be between 0 (exclusive) and 1 (inclusive)."
+    assert any(
+        [output_composition, output_alpha, output_foreground]
+    ), "Must provide at least one output."
+    assert seq_chunk >= 1, "Sequence chunk must be >= 1"
+    assert num_workers >= 0, "Number of workers must be >= 0"
+
+    # Initialize transform
+    if input_resize is not None:
+        transform = transforms.Compose(
+            [transforms.Resize(input_resize[::-1]), transforms.ToTensor()]
+        )
+    else:
+        transform = transforms.ToTensor()
+
+    # Initialize reader
+    source = VideoReader(input_source, transform)
+    reader = DataLoader(
+        source, batch_size=seq_chunk, pin_memory=True, num_workers=num_workers
+    )
+
+    # Initialize writers
+    frame_rate = source.frame_rate if isinstance(source, VideoReader) else 30
+    output_video_mbps = 1 if output_video_mbps is None else output_video_mbps
+    if output_composition is not None:
+        writer_com = VideoWriter(
+            path=output_composition,
+            frame_rate=frame_rate,
+            bit_rate=int(output_video_mbps * 1000000),
+        )
+    if output_alpha is not None:
+        writer_pha = VideoWriter(
+            path=output_alpha,
+            frame_rate=frame_rate,
+            bit_rate=int(output_video_mbps * 1000000),
+        )
+    if output_foreground is not None:
+        writer_fgr = VideoWriter(
+            path=output_foreground,
+            frame_rate=frame_rate,
+            bit_rate=int(output_video_mbps * 1000000),
+        )
+
+    # Inference
+    model = model.eval()
+    if device is None or dtype is None:
+        param = next(model.parameters())
+        dtype = param.dtype
+        device = param.device
+
+    if output_composition is not None:
+        bgr = (
+            torch.tensor([0, 0, 0], device=device, dtype=dtype)
+            .div(255)
+            .view(1, 1, 3, 1, 1)
+        )
+
+    try:
+        with torch.no_grad():
+            bar = tqdm(total=len(source), disable=not progress, dynamic_ncols=True)
+            rec = [None] * 4
+            for src in reader:
+                if downsample_ratio is None:
+                    downsample_ratio = auto_downsample_ratio(*src.shape[2:])
+
+                src = src.to(device, dtype, non_blocking=True).unsqueeze(
+                    0
+                )  # [B, T, C, H, W]
+                fgr, pha, *rec = model(src, *rec, downsample_ratio)
+
+                if output_foreground is not None:
+                    writer_fgr.write(fgr[0])
+                if output_alpha is not None:
+                    writer_pha.write(pha[0])
+                if output_composition is not None:
+                    com = fgr * pha + bgr * (1 - pha)
+                    writer_com.write(com[0])
+
+                bar.update(src.size(1))
+
+    finally:
+        # Clean up
+        if output_composition is not None:
+            writer_com.close()
+        if output_alpha is not None:
+            writer_pha.close()
+        if output_foreground is not None:
+            writer_fgr.close()
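
A minimal usage sketch (not part of the commit): it assumes a RobustVideoMatting-style model whose forward pass has the signature fgr, pha, *rec = model(src, *rec, downsample_ratio); the torch.hub checkpoint and the file paths below are illustrative placeholders.

    import torch

    from convert import convert_video

    # Load a matting model with the expected recurrent interface (assumption:
    # RobustVideoMatting's MobileNetV3 variant via torch.hub).
    model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3")
    if torch.cuda.is_available():
        model = model.cuda()

    convert_video(
        model,
        input_source="input.mp4",            # hypothetical input video
        output_composition="composite.mp4",  # foreground over a black background
        output_alpha="alpha.mp4",            # alpha matte
        output_foreground="foreground.mp4",  # raw foreground
        output_video_mbps=4,
        seq_chunk=12,                        # frames per inference step
    )

seq_chunk trades memory for throughput: the loop feeds the model seq_chunk frames per step while carrying the recurrent states in rec between steps.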