# Movie_Video_Colorization / colorizer.py
import torch
from torchvision import transforms
import numpy as np
from skimage.color import rgb2lab, lab2rgb
import skimage.transform
from PIL import Image
import os
from tqdm import tqdm
from moviepy.editor import VideoFileClip, AudioFileClip
from moviepy.tools import cvsecs
import cv2
def lab_to_rgb(L, ab):
    """
    Convert a batch of Lab tensors back to RGB images.

    L:  (B, 1, H, W) luminance in [-1, 1]
    ab: (B, 2, H, W) chrominance in [-1, 1]
    Returns a (B, H, W, 3) float array of RGB values in [0, 1].
    """
    L = (L + 1.) * 50.  # undo normalization: back to [0, 100]
    ab = ab * 110.  # undo normalization: back to roughly [-110, 110]
    Lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).cpu().numpy()
    rgb_imgs = [lab2rgb(img) for img in Lab]
    return np.stack(rgb_imgs, axis=0)

SIZE = 256

def get_L(img):
    """Extract the normalized L (luminance) channel from a PIL image."""
    img = transforms.Resize(
        (SIZE, SIZE), interpolation=transforms.InterpolationMode.BICUBIC)(img)
    img = np.array(img)
    img_lab = rgb2lab(img).astype("float32")
    img_lab = transforms.ToTensor()(img_lab)
    L = img_lab[[0], ...] / 50. - 1.  # scale L from [0, 100] to [-1, 1]
    return L

def get_predictions(model, L):
    """Run the model on a batch of L channels and return RGB predictions."""
    model.eval()
    with torch.no_grad():
        # inference runs on CPU here; move L to the model's device if on GPU
        model.L = L.to(torch.device('cpu'))
        model.forward()
    fake_color = model.fake_color.detach()
    fake_imgs = lab_to_rgb(L, fake_color)
    return fake_imgs
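
# get_predictions assumes a pix2pix-style wrapper that exposes `.L`, a
# no-argument `.forward()`, and `.fake_color`. The stub below is a hypothetical
# sketch of that interface (it predicts zero chrominance, i.e. grayscale
# output) and is useful only for exercising the pipeline; it is not the
# actual colorization network.
class _StubColorizer(torch.nn.Module):
    def forward(self):
        b, _, h, w = self.L.shape
        self.fake_color = torch.zeros(b, 2, h, w)  # (B, 2, H, W) ab in [-1, 1]
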
def colorize_img(model, img):
    """Colorize a single PIL image, returning an RGB float array in [0, 1]."""
    L = get_L(img)
    L = L[None]  # add a batch dimension: (1, 1, H, W)
    fake_imgs = get_predictions(model, L)
    fake_img = fake_imgs[0]  # drop the batch dimension
    resized_fake_img = skimage.transform.resize(
        fake_img, img.size[::-1])  # PIL size is (W, H); resize wants (H, W)
    return resized_fake_img
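
# Example single-image usage (the path is hypothetical; `model` is any object
# implementing the interface described above):
#
#     img = Image.open("still.png").convert("RGB")
#     rgb = colorize_img(model, img)  # (H, W, 3) float array in [0, 1]
#     Image.fromarray((rgb * 255).astype(np.uint8)).save("still_color.png")
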
def valid_start_end(duration, start_input, end_input):
    """Parse and clamp start/end inputs against the clip duration (seconds)."""
    start = start_input if start_input != '' else 0
    end = end_input if end_input != '' else duration
    try:
        # cvsecs accepts seconds or 'mm:ss' / 'hh:mm:ss' strings
        start = cvsecs(start)
        end = cvsecs(end)
    except Exception:
        raise ValueError("Invalid start/end values.")
    # clamp to the clip's bounds
    start = max(start, 0)
    end = min(duration, end)
    if start >= end:
        raise ValueError("Start must be before end.")
    return start, end
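
# Example: for a 60 s clip, valid_start_end(60, '', '1:30') returns (0, 60):
# the empty start defaults to 0, and cvsecs('1:30') -> 90 is clamped to the
# clip's duration.
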
def colorize_vid(path_input, model, fps, start_input, end_input):
    """Colorize a video between start and end; return the output file path."""
    original_video = VideoFileClip(path_input)
    # validate and clamp start/end against the clip's duration
    start, end = valid_start_end(
        original_video.duration, start_input, end_input)
    input_video = original_video.subclip(start, end)
    if isinstance(fps, int):
        used_fps = fps
        nframes = int(np.round(fps * input_video.duration))
    else:
        used_fps = input_video.fps
        nframes = input_video.reader.nframes
    print(
        f"Colorizing output with FPS: {used_fps}, nframes: {nframes}, "
        f"resolution: {input_video.size}.")
    frames = input_video.iter_frames(fps=used_fps)
    # temporary output path: same as the input path but with a '_tmp' suffix
    base_path, suffix = os.path.splitext(path_input)
    path_video_tmp = base_path + "_tmp" + suffix
    # create a video writer for the colorized frames
    width, height = input_video.size
    out = cv2.VideoWriter(
        path_video_tmp,
        cv2.VideoWriter_fourcc(*'mp4v'),
        used_fps,
        (width, height))
    for frame in tqdm(frames, total=nframes):
        # colorize each frame; colorize_img returns floats in [0, 1]
        color_frame = colorize_img(model, Image.fromarray(frame))
        color_frame = (np.clip(color_frame, 0, 1) * 255).astype(np.uint8)
        # OpenCV expects BGR channel order
        color_frame = cv2.cvtColor(color_frame, cv2.COLOR_RGB2BGR)
        out.write(color_frame)
    out.release()
    # output path: same as the input path but with an '_out' suffix
    path_output = base_path + "_out" + suffix
    output_video = VideoFileClip(path_video_tmp)
    # the cv2-written file has no sound, so extract the subclip's audio
    # separately and attach it to the output
    path_audio_tmp = base_path + "_audio_tmp.mp3"
    if input_video.audio is not None:
        input_video.audio.write_audiofile(path_audio_tmp, logger=None)
        output_video = output_video.set_audio(AudioFileClip(path_audio_tmp))
    output_video.write_videofile(path_output, logger=None)
    os.remove(path_video_tmp)
    if os.path.exists(path_audio_tmp):
        os.remove(path_audio_tmp)
    print("Done.")
    return path_output
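
if __name__ == "__main__":
    # Minimal usage sketch. The checkpoint name and loading call below are
    # assumptions -- how the model is restored depends on the repo's training
    # code and is not shown in this file.
    # model = torch.load("colorizer_checkpoint.pt", map_location="cpu")
    # colorize_vid("movie.mp4", model, fps=24,
    #              start_input="0:05", end_input="0:15")
    pass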