import os
import sys
import argparse
import os.path as osp

import cv2
import torch
import numpy as np

sys.path.append('.')

from utils.utils import img2tensor, tensor2img, InputPadder
from utils.build_utils import build_from_cfg
|
|
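# Network specification consumed by build_from_cfg(): the AMT-G interpolation
# model and its correlation/flow hyper-parameters.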
AMT_G = {
    'name': 'networks.AMT-G.Model',
    'params': {
        'corr_radius': 3,
        'corr_lvls': 4,
        'num_flows': 5,
    }
}
|
|
def init(device="cuda"):
    '''
    Initialize the anchor resolution and the memory budget used to decide
    how much the input video must be downscaled.
    '''
    if device == 'cuda':
        anchor_resolution = 1024 * 512
        anchor_memory = 1500 * 1024**2
        anchor_memory_bias = 2500 * 1024**2
        vram_avail = torch.cuda.get_device_properties(device).total_memory
        print("VRAM available: {:.1f} MB".format(vram_avail / 1024 ** 2))
    else:
        # On CPU there is no VRAM constraint; pick values that force scale = 1.
        anchor_resolution = 8192 * 8192
        anchor_memory = 1
        anchor_memory_bias = 0
        vram_avail = 1
    return anchor_resolution, anchor_memory, anchor_memory_bias, vram_avail
|
|
def get_input_video_from_path(input_path, device="cuda"):
    '''
    Read the input video from input_path and prepare its frames.

    params:
        input_path: str, the path of the input video.
        device: str, the device to run the model on.
    returns:
        inputs: list, the list of input frame tensors.
        scale: float, the scale factor applied during inference.
        padder: InputPadder, the padder used to pad the input frames.
    '''
    anchor_resolution, anchor_memory, anchor_memory_bias, vram_avail = init(device)
|
    video_exts = ['.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv', '.webm']
    if osp.splitext(input_path)[-1].lower() in video_exts:
        vcap = cv2.VideoCapture(input_path)

        inputs = []
        w = int(vcap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(vcap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
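        # VRAM-based scale heuristic (intent inferred from the formula): the
        # anchor values from init() estimate how much resolution the available
        # memory can handle; the result is capped at 1 and snapped to the form
        # 16/k so that the matching padding below is an integer pixel count.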
        scale = anchor_resolution / (h * w) * np.sqrt((vram_avail - anchor_memory_bias) / anchor_memory)
        scale = 1 if scale > 1 else scale
        scale = 1 / np.floor(1 / np.sqrt(scale) * 16) * 16
        if scale < 1:
            print(f"Due to the limited VRAM, the video will be scaled by {scale:.2f}")
        padding = int(16 / scale)
        padder = InputPadder((h, w), padding)
|
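        # Decode all frames up front; each frame is converted to a tensor on
        # `device` and padded by `padder` before interpolation.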
        while True:
            ret, frame = vcap.read()
            if not ret:
                break
            # OpenCV decodes frames as BGR; convert to RGB before tensorizing.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame_t = img2tensor(frame).to(device)
            frame_t = padder.pad(frame_t)
            inputs.append(frame_t)
        print(f'Loaded the [video] from {input_path}; number of frames: [{len(inputs)}]')
    else:
        raise TypeError("Input should be a video.")
|
return inputs, scale, padder |
|
|
def load_model(ckpt_path, device="cuda"):
    '''
    Load the frame interpolation model from a checkpoint.
    '''
    network_cfg = AMT_G
    network_name = network_cfg['name']
    print(f'Loading [{network_name}] from [{ckpt_path}]...')
    model = build_from_cfg(network_cfg)
    # map_location keeps loading functional on CPU-only machines.
    ckpt = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(ckpt['state_dict'])
    model = model.to(device)
    model.eval()
    return model
|
|
def interpolater(model, inputs, scale, padder, iters=1):
    '''
    Interpolate frames with the frame interpolation model.

    params:
        model: nn.Module, the frame interpolation model.
        inputs: list, the list of input frames.
        scale: float, the scale factor used by the model.
        padder: InputPadder, the padder used to pad the input frames.
        iters: int, the number of interpolation iterations. Given m input
            frames, the model produces 2 ** iters * (m - 1) + 1 frames.
    returns:
        outputs: list, the list of output frames.
    '''
    print('Start frame interpolation:')
    # `device` is not in scope inside this function; infer it from the model.
    device = next(model.parameters()).device
    # embt = 0.5 asks the model for the temporal midpoint between two frames.
    embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(device)
|
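    # Each iteration inserts one predicted midpoint frame between every
    # consecutive pair, roughly doubling the number of frames.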
    for i in range(iters):
        print(f'Iter {i+1}. input_frames={len(inputs)} output_frames={2*len(inputs)-1}')
        outputs = [inputs[0]]
        for in_0, in_1 in zip(inputs[:-1], inputs[1:]):
            in_0 = in_0.to(device)
            in_1 = in_1.to(device)
            with torch.no_grad():
                imgt_pred = model(in_0, in_1, embt, scale_factor=scale, eval=True)['imgt_pred']
            outputs += [imgt_pred.cpu(), in_1.cpu()]
        inputs = outputs

    # Remove the padding that was added when the frames were loaded.
    outputs = padder.unpad(*outputs)
    return outputs
|
|
def write(outputs, input_path, output_path, frame_rate=30):
    '''
    Write the interpolated frames to a video file under output_path.
    '''
    os.makedirs(output_path, exist_ok=True)

    # cv2.VideoWriter expects the frame size as (width, height).
    h, w = outputs[0].shape[-2:]
    size = (w, h)

    _, file_name_with_extension = os.path.split(input_path)
    file_name, _ = os.path.splitext(file_name_with_extension)

    save_video_path = f'{output_path}/output_{file_name}.mp4'
    writer = cv2.VideoWriter(save_video_path, cv2.VideoWriter_fourcc(*"mp4v"),
                             frame_rate, size)

    for imgt_pred in outputs:
        imgt_pred = tensor2img(imgt_pred)
        imgt_pred = cv2.cvtColor(imgt_pred, cv2.COLOR_RGB2BGR)
        writer.write(imgt_pred)
    writer.release()
    print(f"Demo video is saved to [{save_video_path}]")
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--ckpt', type=str, default='amt-g.pth', help="Path of the pretrained model checkpoint.")
    parser.add_argument('--niters', type=int, default=1, help="Number of interpolation iterations; the frame count roughly doubles after each iteration.")
    parser.add_argument('--input', default="test.mp4", help="Path of the input video.")
    parser.add_argument('--output_path', type=str, default='results', help="Output directory.")
    parser.add_argument('--frame_rate', type=int, default=30, help="Frame rate of the output video.")

    args = parser.parse_args()
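    # Example invocation (the script filename is illustrative; checkpoint and
    # video paths are the argparse defaults above):
    #   python demo.py --ckpt amt-g.pth --input test.mp4 --niters 2 --frame_rate 60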
|
    # Pass the device as a plain string so init()'s `device == 'cuda'` check
    # behaves predictably.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    ckpt_path = args.ckpt
    input_path = args.input
    output_path = args.output_path
    iters = int(args.niters)
    frame_rate = int(args.frame_rate)

    inputs, scale, padder = get_input_video_from_path(input_path, device)
    model = load_model(ckpt_path, device)
    outputs = interpolater(model, inputs, scale, padder, iters)
    write(outputs, input_path, output_path, frame_rate)