TroglodyteDerivations
committed on
Create gan.py
Browse files
gan.py
ADDED
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
import numpy as np
|
3 |
+
import cv2
|
4 |
+
import torchvision.transforms as transforms
|
5 |
+
import torch
|
6 |
+
import io
|
7 |
+
import os
|
8 |
+
import functools
|
9 |
+
|
10 |
+
class DataLoader():
    """Expose one in-memory image through a ``torch.utils.data.DataLoader``.

    The wrapper builds a ``Dataset`` from the supplied image and hands it to
    torch's loader, configured from the option object.
    """

    def __init__(self, opt, cv_img):
        """Build the dataset and the torch loader around it.

        Args:
            opt: Option object; ``batchSize``, ``serial_batches`` and
                ``nThreads`` are read here, and the object is forwarded to
                ``Dataset.initialize``.
            cv_img: Image forwarded unchanged to ``Dataset.initialize``.
        """
        super(DataLoader, self).__init__()

        single_image_set = Dataset()
        single_image_set.initialize(opt, cv_img)
        self.dataset = single_image_set

        loader_settings = {
            'batch_size': opt.batchSize,
            # Shuffling is enabled unless serial batches were requested.
            'shuffle': not opt.serial_batches,
            'num_workers': int(opt.nThreads),
        }
        self.dataloader = torch.utils.data.DataLoader(self.dataset, **loader_settings)

    def load_data(self):
        """Return the underlying torch ``DataLoader``."""
        return self.dataloader

    def __len__(self):
        """Return 1: the loader always wraps exactly one image."""
        return 1
|
29 |
+
|
30 |
+
class Dataset(torch.utils.data.Dataset):
    """Single-item dataset holding one image converted from OpenCV to PIL."""

    def __init__(self):
        super(Dataset, self).__init__()

    def initialize(self, opt, cv_img):
        """Store the options and convert ``cv_img`` to a PIL image.

        Args:
            opt: Option object; ``dataroot`` is read here and the whole
                object is kept for ``get_transform``.
            cv_img: Array passed to ``cv2.cvtColor`` with
                ``COLOR_BGR2RGB`` (so presumably a BGR OpenCV image).
        """
        self.opt = opt
        self.root = opt.dataroot

        rgb_pixels = cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB)
        self.A = Image.fromarray(rgb_pixels)
        self.dataset_size = 1

    def __getitem__(self, index):
        """Return the input dictionary for the stored image.

        ``index`` is ignored because there is only one item. Only the
        ``'label'`` entry carries data; the remaining tensor slots are the
        integer 0 and ``'path'`` is an empty string.
        """
        to_tensor = get_transform(self.opt)
        label = to_tensor(self.A.convert('RGB'))

        return {
            'label': label,
            'inst': 0,
            'image': 0,
            'feat': 0,
            'path': "",
        }

    def __len__(self):
        """Return 1: the dataset always contains exactly one image."""
        return 1
|
55 |
+
|
56 |
+
class DeepModel(torch.nn.Module):
    """Inference wrapper that builds a ``GlobalGenerator`` and loads its weights.

    Call ``initialize(opt)`` once, then ``inference(label, inst)``.
    """

    def initialize(self, opt):
        """Build the generator described by ``opt`` and load its checkpoint.

        Args:
            opt: Option object. Reads ``input_nc``, ``output_nc``, ``ngf``,
                ``netG``, ``n_downsample_global``, ``n_blocks_global``,
                ``n_local_enhancers``, ``n_blocks_local``, ``norm`` and
                ``checkpoints_dir``.
        """
        torch.cuda.empty_cache()

        self.opt = opt

        # FIX CPU: the model is pinned to CPU by leaving the GPU list empty.
        self.gpu_ids = []

        self.netG = self.__define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG,
                                    opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers,
                                    opt.n_blocks_local, opt.norm, self.gpu_ids)

        # load networks
        self.__load_network(self.netG)

    def inference(self, label, inst):
        """Run the generator on ``label`` without tracking gradients.

        Args:
            label: Input tensor fed to the generator.
            inst: Passed through ``__encode_input`` but not used for
                generation.

        Returns:
            The generator output tensor.
        """
        # Encode Inputs
        input_label, inst_map, _, _ = self.__encode_input(label, inst, infer=True)

        # Fake Generation
        input_concat = input_label

        with torch.no_grad():
            fake_image = self.netG.forward(input_concat)

        return fake_image

    # helper loading function that can be used by subclasses
    def __load_network(self, network):
        """Load the state dict found at ``opt.checkpoints_dir`` into ``network``.

        ``opt.checkpoints_dir`` is used directly as the checkpoint file path.
        """
        save_path = os.path.join(self.opt.checkpoints_dir)

        # Remap storages to the device the model actually lives on. Without
        # this, a checkpoint saved from a GPU fails to deserialize on a host
        # without CUDA, even though the model itself runs on CPU.
        if self.gpu_ids:
            map_location = torch.device('cuda', self.gpu_ids[0])
        else:
            map_location = torch.device('cpu')

        network.load_state_dict(torch.load(save_path, map_location=map_location))

    def __encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):
        """Move ``label_map`` to the GPU when one is configured.

        The other arguments are returned unchanged; ``infer`` is unused.

        Returns:
            Tuple ``(input_label, inst_map, real_image, feat_map)``.
        """
        if (len(self.gpu_ids) > 0):
            input_label = label_map.data.cuda()  # GPU
        else:
            input_label = label_map.data  # CPU

        return input_label, inst_map, real_image, feat_map

    def __weights_init(self, m):
        """Initialise ``m`` in place, selecting the scheme by class name.

        Classes whose name contains ``Conv`` get N(0, 0.02) weights; classes
        whose name contains ``BatchNorm2d`` get N(1, 0.02) weights and zero
        bias. Other modules are left untouched.
        """
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            m.weight.data.normal_(0.0, 0.02)
        elif classname.find('BatchNorm2d') != -1:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)

    def __define_G(self, input_nc, output_nc, ngf, netG, n_downsample_global=3, n_blocks_global=9, n_local_enhancers=1,
                   n_blocks_local=3, norm='instance', gpu_ids=None):
        """Create and initialise a ``GlobalGenerator``.

        The incoming ``netG``, ``n_local_enhancers`` and ``n_blocks_local``
        values are accepted for signature compatibility but not used: a
        ``GlobalGenerator`` is always built.

        Returns:
            The generator, moved to ``gpu_ids[0]`` when ``gpu_ids`` is
            non-empty, with ``__weights_init`` applied to every submodule.
        """
        # Avoid a shared mutable default; None means "no GPUs".
        if gpu_ids is None:
            gpu_ids = []

        norm_layer = self.__get_norm_layer(norm_type=norm)
        netG = GlobalGenerator(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global, norm_layer)

        if len(gpu_ids) > 0:
            netG.cuda(gpu_ids[0])
        netG.apply(self.__weights_init)
        return netG

    def __get_norm_layer(self, norm_type='instance'):
        """Return a factory for non-affine ``InstanceNorm2d`` layers.

        ``norm_type`` is currently ignored; instance normalisation is always
        returned.
        """
        norm_layer = functools.partial(torch.nn.InstanceNorm2d, affine=False)
        return norm_layer
|
124 |
+
|
125 |
+
##############################################################################
|
126 |
+
# Generator
|
127 |
+
##############################################################################
|
128 |
+
class GlobalGenerator(torch.nn.Module):
    """Encoder / residual / decoder convolutional generator.

    The network is a 7x7 stem, ``n_downsampling`` stride-2 convolutions,
    ``n_blocks`` ``ResnetBlock`` modules, ``n_downsampling`` stride-2
    transposed convolutions, and a 7x7 output convolution followed by
    ``Tanh``. All layers live in one ``Sequential`` stored as ``self.model``.
    """

    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=torch.nn.BatchNorm2d,
                 padding_type='reflect'):
        """Assemble the layer stack.

        Args:
            input_nc: Channels of the input tensor.
            output_nc: Channels of the output tensor.
            ngf: Channel count after the stem convolution.
            n_downsampling: Number of down- and up-sampling stages.
            n_blocks: Number of residual blocks; must be non-negative.
            norm_layer: Callable that takes a channel count and returns a
                normalisation module.
            padding_type: Padding mode forwarded to each ``ResnetBlock``.
        """
        assert(n_blocks >= 0)
        super(GlobalGenerator, self).__init__()
        # One in-place ReLU instance is reused at every activation site.
        relu = torch.nn.ReLU(True)

        layers = [
            torch.nn.ReflectionPad2d(3),
            torch.nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
            norm_layer(ngf),
            relu,
        ]

        # downsample: every stage doubles the channel count
        for stage in range(n_downsampling):
            narrow = ngf * 2**stage
            wide = narrow * 2
            layers.extend([
                torch.nn.Conv2d(narrow, wide, kernel_size=3, stride=2, padding=1),
                norm_layer(wide),
                relu,
            ])

        # resnet blocks operate at the widest channel count
        bottleneck = ngf * 2**n_downsampling
        layers.extend(
            ResnetBlock(bottleneck, padding_type=padding_type, activation=relu, norm_layer=norm_layer)
            for _ in range(n_blocks)
        )

        # upsample: every stage halves the channel count
        for stage in range(n_downsampling, 0, -1):
            wide = ngf * 2**stage
            narrow = int(wide / 2)
            layers.extend([
                torch.nn.ConvTranspose2d(wide, narrow, kernel_size=3, stride=2, padding=1, output_padding=1),
                norm_layer(narrow),
                relu,
            ])

        layers.extend([
            torch.nn.ReflectionPad2d(3),
            torch.nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
            torch.nn.Tanh(),
        ])
        self.model = torch.nn.Sequential(*layers)

    def forward(self, input):
        """Apply the full layer stack to ``input``."""
        return self.model(input)
|
157 |
+
|
158 |
+
# Define a resnet block
|
159 |
+
class ResnetBlock(torch.nn.Module):
    """Residual block: ``x + conv_block(x)`` with two 3x3 convolutions."""

    def __init__(self, dim, padding_type, norm_layer, activation=torch.nn.ReLU(True), use_dropout=False):
        """Create the block.

        Args:
            dim: Channel count, identical for input and output.
            padding_type: ``'reflect'``, ``'replicate'`` or ``'zero'``.
            norm_layer: Callable that takes a channel count and returns a
                normalisation module.
            activation: Module applied after the first normalisation.
            use_dropout: When true, insert ``Dropout(0.5)`` after the
                activation.
        """
        super(ResnetBlock, self).__init__()
        self.conv_block = self.__build_conv_block(dim, padding_type, norm_layer, activation, use_dropout)

    def __build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
        """Return the ``Sequential`` forming the residual branch.

        Raises:
            NotImplementedError: If ``padding_type`` is not a supported mode.
        """

        def padded_conv():
            # 'zero' uses the convolution's own padding; the other modes put
            # an explicit padding layer in front of an unpadded convolution.
            if padding_type == 'zero':
                return [torch.nn.Conv2d(dim, dim, kernel_size=3, padding=1)]
            if padding_type == 'reflect':
                border = torch.nn.ReflectionPad2d(1)
            elif padding_type == 'replicate':
                border = torch.nn.ReplicationPad2d(1)
            else:
                raise NotImplementedError('padding [%s] is not implemented' % padding_type)
            return [border, torch.nn.Conv2d(dim, dim, kernel_size=3, padding=0)]

        layers = padded_conv() + [norm_layer(dim), activation]
        if use_dropout:
            layers.append(torch.nn.Dropout(0.5))
        layers += padded_conv() + [norm_layer(dim)]

        return torch.nn.Sequential(*layers)

    def forward(self, x):
        """Add the residual branch output to ``x``."""
        return x + self.conv_block(x)
|
199 |
+
|
200 |
+
# Data utils:
|
201 |
+
def get_transform(opt, method=Image.BICUBIC, normalize=True):
    """Compose the preprocessing pipeline applied to an input image.

    The pipeline resizes the image so both sides are multiples of a base
    derived from the generator depth, converts it to a tensor, and optionally
    normalises each of the three channels with mean 0.5 and std 0.5.

    Args:
        opt: Option object; reads ``n_downsample_global`` and ``netG``, and
            ``n_local_enhancers`` when ``netG == 'local'``.
        method: Resampling filter used for the resize.
        normalize: Whether to append the normalisation step.

    Returns:
        A ``transforms.Compose`` instance.
    """
    base = float(2 ** opt.n_downsample_global)
    if opt.netG == 'local':
        base *= (2 ** opt.n_local_enhancers)

    def snap_to_base(img):
        return __make_power_2(img, base, method)

    steps = [transforms.Lambda(snap_to_base), transforms.ToTensor()]

    if normalize:
        channel_mean = (0.5, 0.5, 0.5)
        channel_std = (0.5, 0.5, 0.5)
        steps.append(transforms.Normalize(channel_mean, channel_std))

    return transforms.Compose(steps)
|
215 |
+
|
216 |
+
def __make_power_2(img, base, method=Image.BICUBIC):
    """Resize ``img`` so both sides are multiples of ``base``.

    Each side is rounded to the nearest multiple of ``base``, but never below
    one multiple. Without that floor, a side shorter than half of ``base``
    would round to 0 and ``resize`` would be asked for a zero-sized image.

    Args:
        img: Object exposing ``size`` as ``(width, height)`` and
            ``resize((w, h), method)`` (a PIL image in this file).
        base: Positive number the sides must be multiples of.
        method: Resampling filter forwarded to ``resize``.

    Returns:
        ``img`` itself when it already has the target size, otherwise the
        resized image.
    """
    ow, oh = img.size

    def snap(length):
        # Nearest multiple of base, clamped to at least one multiple.
        return int(max(round(length / base), 1) * base)

    h = snap(oh)
    w = snap(ow)
    if (h == oh) and (w == ow):
        return img
    return img.resize((w, h), method)
|
223 |
+
|
224 |
+
# Converts a Tensor into a Numpy array
|
225 |
+
# |imtype|: the desired type of the converted numpy array
|
226 |
+
def tensor2im(image_tensor, imtype=np.uint8, normalize=True):
    """Convert a channel-first tensor (or list of tensors) to a numpy image.

    Args:
        image_tensor: A tensor indexed as (channel, row, column), or a list
            of such tensors, which is converted element by element.
        imtype: dtype of the returned array.
        normalize: When true, values are mapped from [-1, 1] to [0, 255];
            otherwise they are only multiplied by 255.

    Returns:
        A channel-last array clipped to [0, 255] and cast to ``imtype``.
        When the channel count is 1 or greater than 3, only the first channel
        is kept and the result is 2-D. For list input, a list of such arrays.
    """
    if isinstance(image_tensor, list):
        return [tensor2im(single, imtype, normalize) for single in image_tensor]

    # detach() lets tensors that still require grad be converted; numpy()
    # raises on them otherwise.
    image_numpy = image_tensor.detach().cpu().float().numpy()
    if normalize:
        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
    else:
        image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
    image_numpy = np.clip(image_numpy, 0, 255)
    if image_numpy.shape[2] == 1 or image_numpy.shape[2] > 3:
        image_numpy = image_numpy[:, :, 0]
    return image_numpy.astype(imtype)
|