Spaces · Runtime error
davidpengg committed on
Commit fa26127 · 1 Parent(s): 56e2a81
init
- app.py +138 -0
- colorizer.py +157 -0
- dcgan.py +190 -0
- examples/1_falcon.mp4 +3 -0
- examples/2_mughal.mp4 +3 -0
- examples/3_wizard.mp4 +3 -0
- examples/4_elgar.mp4 +3 -0
- modelv1.pth +3 -0
- modelv2.pth +3 -0
- requirements.txt +72 -0
app.py
ADDED
@@ -0,0 +1,138 @@
import torch
import gradio as gr
from pytube import YouTube

from pdb import set_trace

from colorizer import colorize_vid
from dcgan import *

# ================================

# EXAMPLE_FPS = "Same as original"
examples = [
    ["examples/1_falcon.mp4", "modelv2", "Same as original"],  # 4:21
    ["examples/2_mughal.mp4", "modelv1", 12],  # 4:30
    ["examples/3_wizard.mp4", "modelv1", 6],  # 7 min
    ["examples/4_elgar.mp4", "modelv2", 6]  # 22 min
]

model_choices = [
    "modelv2",
    "modelv1",
]

loaded_models = {}
for model_weights in model_choices:
    model = torch.load(f"{model_weights}.pth", map_location=torch.device('cpu'))
    model.eval()  # also done in colorizer
    loaded_models[model_weights] = model


def colorize_video(path_video, chosen_model, chosen_fps, start='', end=''):
    if not path_video:
        return
    return colorize_vid(
        path_video,
        loaded_models[chosen_model],
        chosen_fps,
        start,
        end
    )


def download_youtube(url):
    try:
        yt = YouTube(url)
        streams = yt.streams.filter(
            progressive=True,
            file_extension='mp4').order_by('resolution')
        return streams[0].download()
    except BaseException:
        raise Exception("Invalid URL or Video Unavailable")


app = gr.Blocks()
with app:
    gr.Markdown("# <p align='center'>Movie and Video Colorization</p>")
    gr.Markdown(
        """
        <p style='text-align: center'>
        Colorize black-and-white movies or videos with a DCGAN-based model!
        <br>
        Project by David Peng, Annie Lin, Adam Zapatka, and Maggy Lambo.
        </p>
        """
    )

    gr.Markdown("### Step 1: Choose a YouTube video (or upload locally below)")

    youtube_url = gr.Textbox(label="YouTube Video URL")

    youtube_url_btn = gr.Button(value="Extract YouTube Video")

    with gr.Row():
        gr.Markdown("### Step 2: Adjust settings")
        gr.Markdown("### Step 3: Hit \"Colorize\"")
    with gr.Row():
        bw_video = gr.Video(label="Black-and-White Video")
        colorized_video = gr.Video(label="Colorized Video")
    with gr.Row():
        with gr.Column():
            with gr.Row():
                start_time = gr.Text(
                    label="Start Time (hh:mm:ss or blank for original)", value='')
                end_time = gr.Text(
                    label="End Time (hh:mm:ss or blank for original)", value='')
        with gr.Column():
            bw_video_btn = gr.Button(value="Colorize", variant="primary")
    with gr.Row():
        with gr.Column():
            model_dropdown = gr.Dropdown(
                model_choices,
                value=model_choices[0],
                label="Model"
            )

            fps_dropdown = gr.Dropdown(
                [3, 6, 12, 24, 30, "Same as original"],
                value=6,
                label="FPS of Colorized Video"
            )

            gr.Markdown(
                """
                #### Colorization Notes
                - Leave start, end times blank to colorize the entire video
                - To lower colorization time, you can decrease FPS, resolution, or duration
                - *modelv2* tends to color videos orange and sepia
                - *modelv1* tends to color videos with a variety of colors
                - *modelv2* and *modelv1* use the same modified DCGAN architecture but differ in results because of randomization in training

                #### More Reading
                - <a href='https://towardsdatascience.com/colorizing-black-white-images-with-u-net-and-conditional-gan-a-tutorial-81b2df111cd8' target='_blank'>Colorizing black & white images with U-Net and conditional GAN</a>
                - <a href='https://arxiv.org/abs/1803.05400' target='_blank'>Image Colorization with Generative Adversarial Networks</a>
                """
            )
        with gr.Column():
            gr.Examples(
                examples=examples,
                inputs=[bw_video, model_dropdown, fps_dropdown],
                outputs=[colorized_video],
                fn=colorize_video,
                cache_examples=True,
            )

    youtube_url_btn.click(
        download_youtube,
        inputs=youtube_url,
        outputs=bw_video
    )

    bw_video_btn.click(
        colorize_video,
        inputs=[bw_video, model_dropdown, fps_dropdown, start_time, end_time],
        outputs=colorized_video
    )

app.launch()
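
A note on model loading in app.py: torch.load here unpickles full MainModel objects, so the classes defined in dcgan.py must be importable when the checkpoints are read, which is presumably why app.py does "from dcgan import *". For reference, a minimal headless sketch of the same pipeline, assuming the weights and example clips in this commit are available locally (the 10-second window is an arbitrary choice):

import torch
from dcgan import *  # makes MainModel / Unet importable for unpickling the checkpoint
from colorizer import colorize_vid

model = torch.load("modelv2.pth", map_location=torch.device("cpu"))
model.eval()

# Colorize the first ten seconds of the falcon example at 6 fps.
out_path = colorize_vid("examples/1_falcon.mp4", model, 6, "00:00:00", "00:00:10")
print(out_path)  # written next to the input, e.g. examples/1_falcon_out.mp4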
colorizer.py
ADDED
@@ -0,0 +1,157 @@
import torch
from torchvision import transforms

import numpy as np
from skimage.color import rgb2lab, lab2rgb
import skimage.transform
from PIL import Image

import os
from tqdm import tqdm
from moviepy.editor import VideoFileClip, AudioFileClip
from moviepy.tools import cvsecs
import cv2

from pdb import set_trace


def lab_to_rgb(L, ab):
    """
    Takes a batch of L and ab channels scaled to [-1, 1] and returns RGB images.
    """
    L = (L + 1.) * 50.
    ab = ab * 110.
    Lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).cpu().numpy()
    rgb_imgs = []
    for img in Lab:
        img_rgb = lab2rgb(img)
        rgb_imgs.append(img_rgb)
    return np.stack(rgb_imgs, axis=0)


SIZE = 256


def get_L(img):
    img = transforms.Resize(
        (SIZE, SIZE), transforms.InterpolationMode.BICUBIC)(img)
    img = np.array(img)
    img_lab = rgb2lab(img).astype("float32")
    img_lab = transforms.ToTensor()(img_lab)
    L = img_lab[[0], ...] / 50. - 1.  # between -1 and 1

    return L


def get_predictions(model, L):
    # model.L = L.to(model.device)
    model.eval()
    with torch.no_grad():
        model.L = L.to(torch.device('cpu'))
        model.forward()
    fake_color = model.fake_color.detach()
    fake_imgs = lab_to_rgb(L, fake_color)

    return fake_imgs


def colorize_img(model, img):
    L = get_L(img)
    L = L[None]  # add a batch dimension
    fake_imgs = get_predictions(model, L)
    fake_img = fake_imgs[0]  # take the single result out of the batch
    resized_fake_img = skimage.transform.resize(
        fake_img, img.size[::-1])  # reshape to original size

    return resized_fake_img


def valid_start_end(duration, start_input, end_input):
    start = start_input
    end = end_input
    if start == '':
        start = 0
    if end == '':
        end = duration

    try:
        start = cvsecs(start)
        end = cvsecs(end)
    except BaseException:
        # start, end aren't actual time values
        raise Exception("Invalid start, end values")

    # clamp start and end to the clip's bounds
    start = max(start, 0)
    end = min(duration, end)

    # start must be less than end
    if start >= end:
        raise Exception("Start must be before end.")

    return start, end


def colorize_vid(path_input, model, fps, start_input, end_input):

    original_video = VideoFileClip(path_input)

    # validate start, end
    start, end = valid_start_end(
        original_video.duration, start_input, end_input)

    input_video = original_video.subclip(start, end)

    if isinstance(fps, int):
        used_fps = fps
        nframes = np.round(fps * input_video.duration)
    else:
        used_fps = input_video.fps
        nframes = input_video.reader.nframes
    print(
        f"Colorizing output with FPS: {fps}, nframes: {nframes}, resolution: {input_video.size}.")

    frames = input_video.iter_frames(fps=used_fps)

    # create tmp path that is same as input path but with '_tmp.[suffix]'
    base_path, suffix = os.path.splitext(path_input)
    path_video_tmp = base_path + "_tmp" + suffix

    # create video writer for output
    size = input_video.size
    out = cv2.VideoWriter(
        path_video_tmp,
        cv2.VideoWriter_fourcc(
            *'mp4v'),
        used_fps,
        size)
    # out = cv2.VideoWriter(path_video_tmp, cv2.VideoWriter_fourcc(*'DIVX'), used_fps, size)

    for frame in tqdm(frames, total=nframes):
        # get colorized frame
        color_frame = colorize_img(model, Image.fromarray(frame))

        if color_frame.max() <= 1:
            color_frame = (color_frame * 255).astype(np.uint8)

        color_frame = cv2.cvtColor(color_frame, cv2.COLOR_BGR2RGB)
        out.write(color_frame)
    out.release()

    # create output path that is same as input path but with '_out.[suffix]'
    path_output = base_path + "_out" + suffix

    # for some reason, subclip doesn't save audio, so write a tmp audio file
    path_audio_tmp = base_path + "audio_tmp.mp3"
    input_video.audio.write_audiofile(path_audio_tmp, logger=None)
    input_audio = AudioFileClip(path_audio_tmp)

    output_video = VideoFileClip(path_video_tmp)
    output_video = output_video.set_audio(input_audio)
    output_video.write_videofile(path_output, logger=None)

    os.remove(path_video_tmp)
    os.remove(path_audio_tmp)

    print("Done.")
    return path_output
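
For a single frame, colorize_img is the core step: get_L resizes the frame to 256x256 and rescales the L channel of its Lab representation to [-1, 1], the model predicts the two ab channels, lab_to_rgb undoes the scaling (L back to [0, 100], ab back to roughly [-110, 110]) and converts to RGB, and the result is resized back to the original resolution. A small sketch for one still image; the input filename is hypothetical:

import torch
from PIL import Image
from dcgan import *  # model classes must be importable to unpickle the checkpoint
from colorizer import colorize_img

model = torch.load("modelv1.pth", map_location=torch.device("cpu"))
frame = Image.open("bw_frame.png").convert("RGB")  # hypothetical grayscale still, loaded as 3-channel RGB
rgb = colorize_img(model, frame)                   # float array in [0, 1], shape (H, W, 3)
Image.fromarray((rgb * 255).astype("uint8")).save("bw_frame_color.png")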
dcgan.py
ADDED
@@ -0,0 +1,190 @@
import torch
from torch import nn, optim

# this architecture is taken from https://github.com/moein-shariatnia/Deep-Learning/tree/main/Image%20Colorization%20Tutorial

# this is actually the DCGAN generator; during training we kept the class name
# from the original repo to avoid changing code
class Unet(nn.Module):
    def __init__(self, input_c=1, output_c=2, num_filters=128):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(input_c, 64, kernel_size=4, stride=1, padding="same"),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(256, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),

            nn.ConvTranspose2d(512, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.Conv2d(64, output_c, kernel_size=1, stride=1),
            nn.Tanh()
        )

    def forward(self, x):
        return self.model(x)


class PatchDiscriminator(nn.Module):
    def __init__(self, input_c, num_filters=64, n_down=3):
        super().__init__()
        model = [self.get_layers(input_c, num_filters, norm=False)]
        model += [self.get_layers(num_filters * 2 ** i, num_filters * 2 ** (i + 1), s=1 if i == (n_down - 1) else 2)
                  for i in range(n_down)]  # no stride of 2 for the last block in this loop
        model += [self.get_layers(num_filters * 2 ** n_down, 1, s=1, norm=False, act=False)]  # no normalization or activation for the last layer
        self.model = nn.Sequential(*model)

    def get_layers(self, ni, nf, k=4, s=2, p=1, norm=True, act=True):
        # helper for the repeated Conv -> (BatchNorm) -> (LeakyReLU) blocks
        layers = [nn.Conv2d(ni, nf, k, s, p, bias=not norm)]
        if norm: layers += [nn.BatchNorm2d(nf)]
        if act: layers += [nn.LeakyReLU(0.2, True)]
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class GANLoss(nn.Module):
    def __init__(self, gan_mode='vanilla', real_label=1.0, fake_label=0.0):
        super().__init__()
        self.register_buffer('real_label', torch.tensor(real_label))
        self.register_buffer('fake_label', torch.tensor(fake_label))
        if gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'lsgan':
            self.loss = nn.MSELoss()

    def get_labels(self, preds, target_is_real):
        if target_is_real:
            labels = self.real_label
        else:
            labels = self.fake_label
        return labels.expand_as(preds)

    def __call__(self, preds, target_is_real):
        labels = self.get_labels(preds, target_is_real)
        loss = self.loss(preds, labels)
        return loss


def init_weights(net, init='norm', gain=0.02):

    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and 'Conv' in classname:
            if init == 'norm':
                nn.init.normal_(m.weight.data, mean=0.0, std=gain)
            elif init == 'xavier':
                nn.init.xavier_normal_(m.weight.data, gain=gain)
            elif init == 'kaiming':
                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')

            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            nn.init.normal_(m.weight.data, 1., gain)
            nn.init.constant_(m.bias.data, 0.)

    net.apply(init_func)
    print(f"model initialized with {init} initialization")
    return net


def init_model(model, device):
    model = model.to(device)
    model = init_weights(model)
    return model


class MainModel(nn.Module):
    def __init__(self, net_G=None, lr_G=2e-4, lr_D=2e-4,
                 beta1=0.5, beta2=0.999, lambda_L1=100.):
        super().__init__()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.lambda_L1 = lambda_L1

        if net_G is None:
            self.net_G = init_model(Unet(input_c=1, output_c=2, num_filters=64), self.device)
        else:
            self.net_G = net_G.to(self.device)
        self.net_D = init_model(PatchDiscriminator(input_c=3, n_down=3, num_filters=64), self.device)
        self.GANcriterion = GANLoss(gan_mode='vanilla').to(self.device)
        self.L1criterion = nn.L1Loss()
        self.opt_G = optim.Adam(self.net_G.parameters(), lr=lr_G, betas=(beta1, beta2))
        self.opt_D = optim.Adam(self.net_D.parameters(), lr=lr_D, betas=(beta1, beta2))

    def set_requires_grad(self, model, requires_grad=True):
        for p in model.parameters():
            p.requires_grad = requires_grad

    def setup_input(self, data):
        self.L = data['L'].to(self.device)
        self.ab = data['ab'].to(self.device)

    def forward(self):
        self.fake_color = self.net_G(self.L)

    def backward_D(self, epoch):
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        fake_preds = self.net_D(fake_image.detach())
        self.loss_D_fake = self.GANcriterion(fake_preds, False)
        real_image = torch.cat([self.L, self.ab], dim=1)
        real_preds = self.net_D(real_image)
        self.loss_D_real = self.GANcriterion(real_preds, True)
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        # offset discriminator training
        if epoch % 2 == 0:
            self.loss_D.backward()

    def backward_G(self):
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        fake_preds = self.net_D(fake_image)
        self.loss_G_GAN = self.GANcriterion(fake_preds, True)
        self.loss_G_L1 = self.L1criterion(self.fake_color, self.ab) * self.lambda_L1
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize(self, epoch):
        self.forward()
        self.net_D.train()
        self.set_requires_grad(self.net_D, True)
        self.opt_D.zero_grad()
        self.backward_D(epoch)
        if epoch % 2 == 0:
            self.opt_D.step()

        self.net_G.train()
        self.set_requires_grad(self.net_D, False)
        self.opt_G.zero_grad()
        self.backward_G()
        self.opt_G.step()


# with torch.no_grad():
#     model = MainModel()
#     set_trace()
#     # model = torch.load("modelbatchv2.pth", map_location=device)
#     model.load_state_dict(torch.load("modelbatchv2.pth", map_location=torch.device('cpu')).state_dict())
#     assert model.device.type == "cpu"
#     model.eval()
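
Despite its name, the Unet class above is a plain encoder-decoder with no skip connections: a stride-1 convolution plus five stride-2 convolutions take the 1-channel L input from 256x256 down to 8x8, five transposed convolutions bring it back up, and a final 1x1 convolution with tanh emits the 2-channel ab prediction in [-1, 1]. A quick shape check of the generator on CPU, as a sketch:

import torch
from dcgan import Unet

net_G = Unet(input_c=1, output_c=2)
net_G.eval()  # use BatchNorm running stats; enough for a shape check
x = torch.randn(1, 1, 256, 256)  # a fake batch containing one L-channel frame
with torch.no_grad():
    ab = net_G(x)
print(ab.shape)  # torch.Size([1, 2, 256, 256])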
examples/1_falcon.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:376f3ae4689b715739276f491f089445cd552031aacbf274d32665e64a4fb188
size 970446
examples/2_mughal.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:70ef48d7bfd3a9c0754e320d16845520c5da89a84a3feaa7dde375cea6d2af37
size 55261144
examples/3_wizard.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a76046df396bb5ced86ac16de3ed56fc227ba08b3ec7e2552ee56c2336021f84
size 6026610
examples/4_elgar.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:de8fb047d3d251e5af3f5e003601b7f8e23f6ebf327812370656ae304d0b645b
size 19892928
modelv1.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f52e81eafae487dd13bc0b51193c871e55f7fa7d045baf1b3aeede6ce5e1dbee
size 221098733
modelv2.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e8a4aa8fbb7e85d01e433bd7633eedb669c9ac51f434df07957b9157d366e3cf
size 246295842
requirements.txt
ADDED
@@ -0,0 +1,72 @@
aiohttp==3.8.3
aiosignal==1.3.1
altair==4.2.0
anyio==3.6.2
async-timeout==4.0.2
attrs==22.2.0
charset-normalizer==2.1.1
click==8.1.3
contourpy==1.0.6
cycler==0.11.0
decorator==4.4.2
entrypoints==0.4
fastapi==0.88.0
ffmpy==0.3.0
fonttools==4.38.0
frozenlist==1.3.3
fsspec==2022.11.0
gradio==3.15.0
h11==0.14.0
httpcore==0.16.3
httpx==0.23.1
idna==3.4
imageio==2.23.0
imageio-ffmpeg==0.4.7
Jinja2==3.1.2
jsonschema==4.17.3
kiwisolver==1.4.4
linkify-it-py==1.0.3
markdown-it-py==2.1.0
MarkupSafe==2.1.1
matplotlib==3.6.2
mdit-py-plugins==0.3.3
mdurl==0.1.2
moviepy==1.0.3
multidict==6.0.4
networkx==2.8.8
numpy==1.24.1
opencv-python==4.7.0.68
orjson==3.8.3
packaging==22.0
pandas==1.5.2
Pillow==9.3.0
proglog==0.1.10
pycryptodome==3.16.0
pydantic==1.10.4
pydub==0.25.1
pyparsing==3.0.9
pyrsistent==0.19.3
python-dateutil==2.8.2
python-multipart==0.0.5
pytube==12.1.2
pytz==2022.7
PyWavelets==1.4.1
PyYAML==6.0
requests==2.28.1
rfc3986==1.5.0
scikit-image==0.19.3
scipy==1.9.3
six==1.16.0
sniffio==1.3.0
starlette==0.22.0
tifffile==2022.10.10
toolz==0.12.0
torch==1.13.1
torchvision==0.14.1
tqdm==4.64.1
typing_extensions==4.4.0
uc-micro-py==1.0.1
urllib3==1.26.13
uvicorn==0.20.0
websockets==10.4
yarl==1.8.2